Import libtextclassifier
Test: atest TextClassifierServiceTest
Change-Id: Ief715193072d0af3aea230c3c9475ef18e8ac84c
diff --git a/java/Android.bp b/java/Android.bp
index f457e9e..f6cc0ed 100644
--- a/java/Android.bp
+++ b/java/Android.bp
@@ -47,7 +47,8 @@
"androidx.annotation_annotation",
"guava",
"androidx.room_room-runtime",
- "textclassifier-statsd"
+ "textclassifier-statsd",
+ "error_prone_annotations",
],
sdk_version: "system_current",
min_sdk_version: "28",
diff --git a/java/AndroidManifest.xml b/java/AndroidManifest.xml
index d4be7c7..98308db 100644
--- a/java/AndroidManifest.xml
+++ b/java/AndroidManifest.xml
@@ -25,14 +25,15 @@
<manifest xmlns:android="http://schemas.android.com/apk/res/android"
package="com.android.textclassifier"
android:versionCode="1"
- android:versionName="R-initial">
+ android:versionName="1.0.0">
- <uses-sdk android:minSdkVersion="28" android:targetSdkVersion="28"/>
+ <uses-sdk android:minSdkVersion="29" android:targetSdkVersion="29"/>
- <uses-permission android:name="android.permission.ACCESS_FINE_LOCATION" />
+ <uses-permission android:name="android.permission.ACCESS_COARSE_LOCATION" />
<application android:label="@string/app_name"
- android:icon="@drawable/app_icon">
+ android:icon="@drawable/app_icon"
+ android:extractNativeLibs="false">
<service
android:exported="true"
android:name=".DefaultTextClassifierService"
diff --git a/java/checkstyle.sh b/java/checkstyle.sh
deleted file mode 100755
index 39443d5..0000000
--- a/java/checkstyle.sh
+++ /dev/null
@@ -1,2 +0,0 @@
-cd ${ANDROID_BUILD_TOP}/vendor/google_experimental
-${ANDROID_BUILD_TOP}/prebuilts/checkstyle/checkstyle.py
\ No newline at end of file
diff --git a/java/google-java-format.sh b/java/google-java-format.sh
deleted file mode 100755
index 867cf9c..0000000
--- a/java/google-java-format.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-find src -name "*.java" -exec google-java-format -r -a {} \;
-find tests/unittests/src -name "*.java" -exec google-java-format -r -a {} \;
-find tests/robotests/src -name "*.java" -exec google-java-format -r -a {} \;
\ No newline at end of file
diff --git a/java/res/values/strings.xml b/java/res/values/strings.xml
index 60d1859..f8e410a 100644
--- a/java/res/values/strings.xml
+++ b/java/res/values/strings.xml
@@ -1,5 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<resources xmlns:xliff="urn:oasis:names:tc:xliff:document:1.2">
+ <!-- Label for this app [CHAR LIMIT=30] -->
<string name="app_name">Text classifier</string>
<!-- Label for item in the text selection menu to trigger an Email app. Should be a verb. [CHAR LIMIT=30] -->
<string name="email">Email</string>
diff --git a/java/src/com/android/textclassifier/ActionsModelParamsSupplier.java b/java/src/com/android/textclassifier/ActionsModelParamsSupplier.java
index 8d122ad..a1bf109 100644
--- a/java/src/com/android/textclassifier/ActionsModelParamsSupplier.java
+++ b/java/src/com/android/textclassifier/ActionsModelParamsSupplier.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -11,202 +11,191 @@
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
- * limitations under the License
+ * limitations under the License.
*/
+
package com.android.textclassifier;
-import android.content.ContentResolver;
import android.content.Context;
import android.database.ContentObserver;
import android.provider.Settings;
-
import androidx.annotation.GuardedBy;
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-import androidx.core.util.Preconditions;
-
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
import java.lang.ref.WeakReference;
import java.util.Objects;
import java.util.function.Supplier;
+import javax.annotation.Nullable;
/** Parses the {@link Settings.Global#TEXT_CLASSIFIER_ACTION_MODEL_PARAMS} flag. */
+// TODO(tonymak): Re-enable this class.
public final class ActionsModelParamsSupplier
- implements Supplier<ActionsModelParamsSupplier.ActionsModelParams> {
- private static final String TAG = "ActionsModelParamsSupplier";
+ implements Supplier<ActionsModelParamsSupplier.ActionsModelParams> {
+ private static final String TAG = "ActionsModelParams";
- @VisibleForTesting static final String KEY_REQUIRED_MODEL_VERSION = "required_model_version";
- @VisibleForTesting static final String KEY_REQUIRED_LOCALES = "required_locales";
+ @VisibleForTesting static final String KEY_REQUIRED_MODEL_VERSION = "required_model_version";
+ @VisibleForTesting static final String KEY_REQUIRED_LOCALES = "required_locales";
- @VisibleForTesting
- static final String KEY_SERIALIZED_PRECONDITIONS = "serialized_preconditions";
+ @VisibleForTesting static final String KEY_SERIALIZED_PRECONDITIONS = "serialized_preconditions";
- private final Context mAppContext;
- private final SettingsObserver mSettingsObserver;
+ private final Context appContext;
+ private final SettingsObserver settingsObserver;
- private final Object mLock = new Object();
- private final Runnable mOnChangedListener;
+ private final Object lock = new Object();
+ private final Runnable onChangedListener;
- @Nullable
- @GuardedBy("mLock")
- private ActionsModelParams mActionsModelParams;
+ @Nullable
+ @GuardedBy("lock")
+ private ActionsModelParams actionsModelParams;
- @GuardedBy("mLock")
- private boolean mParsed = true;
+ @GuardedBy("lock")
+ private boolean parsed = true;
- public ActionsModelParamsSupplier(Context context, @Nullable Runnable onChangedListener) {
- mAppContext = Preconditions.checkNotNull(context).getApplicationContext();
- mOnChangedListener = onChangedListener == null ? () -> {} : onChangedListener;
- mSettingsObserver =
- new SettingsObserver(
- mAppContext,
- () -> {
- synchronized (mLock) {
- TcLog.v(
- TAG,
- "Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS is"
- + " updated");
- mParsed = true;
- mOnChangedListener.run();
- }
- });
+ public ActionsModelParamsSupplier(Context context, @Nullable Runnable onChangedListener) {
+ appContext = Preconditions.checkNotNull(context).getApplicationContext();
+ this.onChangedListener = onChangedListener == null ? () -> {} : onChangedListener;
+ settingsObserver =
+ new SettingsObserver(
+ appContext,
+ () -> {
+ synchronized (lock) {
+ TcLog.v(TAG, "Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS is updated");
+ parsed = true;
+ this.onChangedListener.run();
+ }
+ });
+ }
+
+ /**
+ * Returns the parsed actions params or {@link ActionsModelParams#INVALID} if the value is
+ * invalid.
+ */
+ @Override
+ public ActionsModelParams get() {
+ synchronized (lock) {
+ if (parsed) {
+ actionsModelParams = parse();
+ parsed = false;
+ }
+ return actionsModelParams;
+ }
+ }
+
+ private static ActionsModelParams parse() {
+ // String settingStr = Settings.Global.getString(contentResolver,
+ // Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS);
+ // if (TextUtils.isEmpty(settingStr)) {
+ // return ActionsModelParams.INVALID;
+ // }
+ // try {
+ // KeyValueListParser keyValueListParser = new KeyValueListParser(',');
+ // keyValueListParser.setString(settingStr);
+ // int version = keyValueListParser.getInt(KEY_REQUIRED_MODEL_VERSION, -1);
+ // if (version == -1) {
+ // TcLog.w(TAG, "ActionsModelParams.Parse, invalid model version");
+ // return ActionsModelParams.INVALID;
+ // }
+ // String locales = keyValueListParser.getString(KEY_REQUIRED_LOCALES, null);
+ // if (locales == null) {
+ // TcLog.w(TAG, "ActionsModelParams.Parse, invalid locales");
+ // return ActionsModelParams.INVALID;
+ // }
+ // String serializedPreconditionsStr =
+ // keyValueListParser.getString(KEY_SERIALIZED_PRECONDITIONS, null);
+ // if (serializedPreconditionsStr == null) {
+ // TcLog.w(TAG, "ActionsModelParams.Parse, invalid preconditions");
+ // return ActionsModelParams.INVALID;
+ // }
+ // byte[] serializedPreconditions =
+ // Base64.decode(serializedPreconditionsStr, Base64.NO_WRAP);
+ // return new ActionsModelParams(version, locales, serializedPreconditions);
+ // } catch (Throwable t) {
+ // TcLog.e(TAG, "Invalid TEXT_CLASSIFIER_ACTION_MODEL_PARAMS, ignore", t);
+ // }
+ return ActionsModelParams.INVALID;
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ appContext.getContentResolver().unregisterContentObserver(settingsObserver);
+ } finally {
+ super.finalize();
+ }
+ }
+
+ /** Represents the parsed result. */
+ public static final class ActionsModelParams {
+
+ public static final ActionsModelParams INVALID = new ActionsModelParams(-1, "", new byte[0]);
+
+ /** The required model version to apply {@code serializedPreconditions}. */
+ private final int requiredModelVersion;
+
+ /** The required model locales to apply {@code serializedPreconditions}. */
+ private final String requiredModelLocales;
+
+ /**
+ * The serialized params that will be applied to the model file, if all requirements are met. Do
+ * not modify.
+ */
+ private final byte[] serializedPreconditions;
+
+ public ActionsModelParams(
+ int requiredModelVersion, String requiredModelLocales, byte[] serializedPreconditions) {
+ this.requiredModelVersion = requiredModelVersion;
+ this.requiredModelLocales = Preconditions.checkNotNull(requiredModelLocales);
+ this.serializedPreconditions = Preconditions.checkNotNull(serializedPreconditions);
}
/**
- * Returns the parsed actions params or {@link ActionsModelParams#INVALID} if the value is
- * invalid.
+ * Returns the serialized preconditions. Returns {@code null} if the model in use does not meet
+ * all the requirements listed in the {@code ActionsModelParams} or the params are invalid.
*/
- @Override
- public ActionsModelParams get() {
- synchronized (mLock) {
- if (mParsed) {
- mActionsModelParams = parse(mAppContext.getContentResolver());
- mParsed = false;
- }
- return mActionsModelParams;
- }
+ @Nullable
+ public byte[] getSerializedPreconditions(ModelFileManager.ModelFile modelInUse) {
+ if (this == INVALID) {
+ return null;
+ }
+ if (modelInUse.getVersion() != requiredModelVersion) {
+ TcLog.w(
+ TAG,
+ String.format(
+ "Not applying mSerializedPreconditions, required version=%d, actual=%d",
+ requiredModelVersion, modelInUse.getVersion()));
+ return null;
+ }
+ if (!Objects.equals(modelInUse.getSupportedLocalesStr(), requiredModelLocales)) {
+ TcLog.w(
+ TAG,
+ String.format(
+ "Not applying mSerializedPreconditions, required locales=%s, actual=%s",
+ requiredModelLocales, modelInUse.getSupportedLocalesStr()));
+ return null;
+ }
+ return serializedPreconditions;
}
+ }
- private ActionsModelParams parse(ContentResolver contentResolver) {
- // String settingStr = Settings.Global.getString(contentResolver,
- // Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS);
- // if (TextUtils.isEmpty(settingStr)) {
- // return ActionsModelParams.INVALID;
- // }
- // try {
- // KeyValueListParser keyValueListParser = new KeyValueListParser(',');
- // keyValueListParser.setString(settingStr);
- // int version = keyValueListParser.getInt(KEY_REQUIRED_MODEL_VERSION, -1);
- // if (version == -1) {
- // TcLog.w(TAG, "ActionsModelParams.Parse, invalid model version");
- // return ActionsModelParams.INVALID;
- // }
- // String locales = keyValueListParser.getString(KEY_REQUIRED_LOCALES, null);
- // if (locales == null) {
- // TcLog.w(TAG, "ActionsModelParams.Parse, invalid locales");
- // return ActionsModelParams.INVALID;
- // }
- // String serializedPreconditionsStr =
- // keyValueListParser.getString(KEY_SERIALIZED_PRECONDITIONS, null);
- // if (serializedPreconditionsStr == null) {
- // TcLog.w(TAG, "ActionsModelParams.Parse, invalid preconditions");
- // return ActionsModelParams.INVALID;
- // }
- // byte[] serializedPreconditions =
- // Base64.decode(serializedPreconditionsStr, Base64.NO_WRAP);
- // return new ActionsModelParams(version, locales, serializedPreconditions);
- // } catch (Throwable t) {
- // TcLog.e(TAG, "Invalid TEXT_CLASSIFIER_ACTION_MODEL_PARAMS, ignore", t);
- // }
- return ActionsModelParams.INVALID;
+ private static final class SettingsObserver extends ContentObserver {
+
+ private final WeakReference<Runnable> onChangedListener;
+
+ SettingsObserver(Context appContext, Runnable listener) {
+ super(null);
+ onChangedListener = new WeakReference<>(listener);
+ // appContext.getContentResolver().registerContentObserver(
+ //
+ // Settings.Global.getUriFor(Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS),
+ // false /* notifyForDescendants */,
+ // this);
}
@Override
- protected void finalize() throws Throwable {
- try {
- mAppContext.getContentResolver().unregisterContentObserver(mSettingsObserver);
- } finally {
- super.finalize();
- }
+ public void onChange(boolean selfChange) {
+ if (onChangedListener.get() != null) {
+ onChangedListener.get().run();
+ }
}
-
- /** Represents the parsed result. */
- public static final class ActionsModelParams {
-
- public static final ActionsModelParams INVALID =
- new ActionsModelParams(-1, "", new byte[0]);
-
- /** The required model version to apply {@code mSerializedPreconditions}. */
- private final int mRequiredModelVersion;
-
- /** The required model locales to apply {@code mSerializedPreconditions}. */
- private final String mRequiredModelLocales;
-
- /**
- * The serialized params that will be applied to the model file, if all requirements are
- * met. Do not modify.
- */
- private final byte[] mSerializedPreconditions;
-
- public ActionsModelParams(
- int requiredModelVersion,
- String requiredModelLocales,
- byte[] serializedPreconditions) {
- mRequiredModelVersion = requiredModelVersion;
- mRequiredModelLocales = Preconditions.checkNotNull(requiredModelLocales);
- mSerializedPreconditions = Preconditions.checkNotNull(serializedPreconditions);
- }
-
- /**
- * Returns the serialized preconditions. Returns {@code null} if the the model in use does
- * not meet all the requirements listed in the {@code ActionsModelParams} or the params are
- * invalid.
- */
- @Nullable
- public byte[] getSerializedPreconditions(ModelFileManager.ModelFile modelInUse) {
- if (this == INVALID) {
- return null;
- }
- if (modelInUse.getVersion() != mRequiredModelVersion) {
- TcLog.w(
- TAG,
- String.format(
- "Not applying mSerializedPreconditions, required version=%d,"
- + " actual=%d",
- mRequiredModelVersion, modelInUse.getVersion()));
- return null;
- }
- if (!Objects.equals(modelInUse.getSupportedLocalesStr(), mRequiredModelLocales)) {
- TcLog.w(
- TAG,
- String.format(
- "Not applying mSerializedPreconditions, required locales=%s,"
- + " actual=%s",
- mRequiredModelLocales, modelInUse.getSupportedLocalesStr()));
- return null;
- }
- return mSerializedPreconditions;
- }
- }
-
- private static final class SettingsObserver extends ContentObserver {
-
- private final WeakReference<Runnable> mOnChangedListener;
-
- SettingsObserver(Context appContext, Runnable listener) {
- super(null);
- mOnChangedListener = new WeakReference<>(listener);
- // appContext.getContentResolver().registerContentObserver(
- //
- // Settings.Global.getUriFor(Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS),
- // false /* notifyForDescendants */,
- // this);
- }
-
- @Override
- public void onChange(boolean selfChange) {
- if (mOnChangedListener.get() != null) {
- mOnChangedListener.get().run();
- }
- }
- }
+ }
}
diff --git a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
index 41f0449..55ee402 100644
--- a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
+++ b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
- * limitations under the License
+ * limitations under the License.
*/
package com.android.textclassifier;
@@ -26,17 +26,11 @@
import android.util.Pair;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
-
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-
import com.android.textclassifier.intent.LabeledIntent;
import com.android.textclassifier.intent.TemplateIntentFactory;
import com.android.textclassifier.logging.ResultIdUtils;
-
import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.RemoteActionTemplate;
-
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
@@ -46,182 +40,178 @@
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
+import javax.annotation.Nullable;
/** Helper class for action suggestions. */
-@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
-public final class ActionsSuggestionsHelper {
- private static final String TAG = "ActionsSuggestions";
- private static final int USER_LOCAL = 0;
- private static final int FIRST_NON_LOCAL_USER = 1;
+final class ActionsSuggestionsHelper {
+ private static final String TAG = "ActionsSuggestions";
+ private static final int USER_LOCAL = 0;
+ private static final int FIRST_NON_LOCAL_USER = 1;
- private ActionsSuggestionsHelper() {}
+ private ActionsSuggestionsHelper() {}
- /**
- * Converts the messages to a list of native messages object that the model can understand.
- *
- * <p>User id encoding - local user is represented as 0, Other users are numbered according to
- * how far before they spoke last time in the conversation. For example, considering this
- * conversation:
- *
- * <ul>
- * <li>User A: xxx
- * <li>Local user: yyy
- * <li>User B: zzz
- * </ul>
- *
- * User A will be encoded as 2, user B will be encoded as 1 and local user will be encoded as 0.
- */
- public static ActionsSuggestionsModel.ConversationMessage[] toNativeMessages(
- List<ConversationActions.Message> messages,
- Function<CharSequence, List<String>> languageDetector) {
- List<ConversationActions.Message> messagesWithText =
- messages.stream()
- .filter(message -> !TextUtils.isEmpty(message.getText()))
- .collect(Collectors.toCollection(ArrayList::new));
- if (messagesWithText.isEmpty()) {
- return new ActionsSuggestionsModel.ConversationMessage[0];
- }
- Deque<ActionsSuggestionsModel.ConversationMessage> nativeMessages = new ArrayDeque<>();
- PersonEncoder personEncoder = new PersonEncoder();
- int size = messagesWithText.size();
- for (int i = size - 1; i >= 0; i--) {
- ConversationActions.Message message = messagesWithText.get(i);
- long referenceTime =
- message.getReferenceTime() == null
- ? 0
- : message.getReferenceTime().toInstant().toEpochMilli();
- String timeZone =
- message.getReferenceTime() == null
- ? null
- : message.getReferenceTime().getZone().getId();
- nativeMessages.push(
- new ActionsSuggestionsModel.ConversationMessage(
- personEncoder.encode(message.getAuthor()),
- message.getText().toString(),
- referenceTime,
- timeZone,
- String.join(",", languageDetector.apply(message.getText()))));
- }
- return nativeMessages.toArray(
- new ActionsSuggestionsModel.ConversationMessage[nativeMessages.size()]);
+ /**
+ * Converts the messages to a list of native messages object that the model can understand.
+ *
+ * <p>User id encoding - local user is represented as 0, Other users are numbered according to how
+ * far before they spoke last time in the conversation. For example, considering this
+ * conversation:
+ *
+ * <ul>
+ * <li>User A: xxx
+ * <li>Local user: yyy
+ * <li>User B: zzz
+ * </ul>
+ *
+ * User A will be encoded as 2, user B will be encoded as 1 and local user will be encoded as 0.
+ */
+ public static ActionsSuggestionsModel.ConversationMessage[] toNativeMessages(
+ List<ConversationActions.Message> messages,
+ Function<CharSequence, List<String>> languageDetector) {
+ List<ConversationActions.Message> messagesWithText =
+ messages.stream()
+ .filter(message -> !TextUtils.isEmpty(message.getText()))
+ .collect(Collectors.toCollection(ArrayList::new));
+ if (messagesWithText.isEmpty()) {
+ return new ActionsSuggestionsModel.ConversationMessage[0];
}
-
- /** Returns the result id for logging. */
- public static String createResultId(
- Context context,
- List<ConversationActions.Message> messages,
- int modelVersion,
- List<Locale> modelLocales) {
- final int hash =
- Objects.hash(
- messages.stream().mapToInt(ActionsSuggestionsHelper::hashMessage),
- context.getPackageName(),
- System.currentTimeMillis());
- return ResultIdUtils.createId(modelVersion, modelLocales, hash);
+ Deque<ActionsSuggestionsModel.ConversationMessage> nativeMessages = new ArrayDeque<>();
+ PersonEncoder personEncoder = new PersonEncoder();
+ int size = messagesWithText.size();
+ for (int i = size - 1; i >= 0; i--) {
+ ConversationActions.Message message = messagesWithText.get(i);
+ long referenceTime =
+ message.getReferenceTime() == null
+ ? 0
+ : message.getReferenceTime().toInstant().toEpochMilli();
+ String timeZone =
+ message.getReferenceTime() == null ? null : message.getReferenceTime().getZone().getId();
+ nativeMessages.push(
+ new ActionsSuggestionsModel.ConversationMessage(
+ personEncoder.encode(message.getAuthor()),
+ message.getText().toString(),
+ referenceTime,
+ timeZone,
+ String.join(",", languageDetector.apply(message.getText()))));
}
+ return nativeMessages.toArray(
+ new ActionsSuggestionsModel.ConversationMessage[nativeMessages.size()]);
+ }
- /** Generated labeled intent from an action suggestion and return the resolved result. */
- @Nullable
- public static LabeledIntent.Result createLabeledIntentResult(
- Context context,
- TemplateIntentFactory templateIntentFactory,
- ActionsSuggestionsModel.ActionSuggestion nativeSuggestion) {
- RemoteActionTemplate[] remoteActionTemplates = nativeSuggestion.getRemoteActionTemplates();
- if (remoteActionTemplates == null) {
- TcLog.w(
- TAG,
- "createRemoteAction: Missing template for type "
- + nativeSuggestion.getActionType());
- return null;
- }
- List<LabeledIntent> labeledIntents = templateIntentFactory.create(remoteActionTemplates);
- if (labeledIntents.isEmpty()) {
- return null;
- }
- // Given that we only support implicit intent here, we should expect there is just one
- // intent for each action type.
- LabeledIntent.TitleChooser titleChooser =
- ActionsSuggestionsHelper.createTitleChooser(nativeSuggestion.getActionType());
- return labeledIntents.get(0).resolve(context, titleChooser, null);
+ /** Returns the result id for logging. */
+ public static String createResultId(
+ Context context,
+ List<ConversationActions.Message> messages,
+ int modelVersion,
+ List<Locale> modelLocales) {
+ final int hash =
+ Objects.hash(
+ messages.stream().mapToInt(ActionsSuggestionsHelper::hashMessage),
+ context.getPackageName(),
+ System.currentTimeMillis());
+ return ResultIdUtils.createId(modelVersion, modelLocales, hash);
+ }
+
+ /** Generated labeled intent from an action suggestion and return the resolved result. */
+ @Nullable
+ public static LabeledIntent.Result createLabeledIntentResult(
+ Context context,
+ TemplateIntentFactory templateIntentFactory,
+ ActionsSuggestionsModel.ActionSuggestion nativeSuggestion) {
+ RemoteActionTemplate[] remoteActionTemplates = nativeSuggestion.getRemoteActionTemplates();
+ if (remoteActionTemplates == null) {
+ TcLog.w(
+ TAG, "createRemoteAction: Missing template for type " + nativeSuggestion.getActionType());
+ return null;
}
-
- /** Returns a {@link LabeledIntent.TitleChooser} for conversation actions use case. */
- @Nullable
- public static LabeledIntent.TitleChooser createTitleChooser(String actionType) {
- if (ConversationAction.TYPE_OPEN_URL.equals(actionType)) {
- return (labeledIntent, resolveInfo) -> {
- if (resolveInfo.handleAllWebDataURI) {
- return labeledIntent.titleWithEntity;
- }
- if ("android".equals(resolveInfo.activityInfo.packageName)) {
- return labeledIntent.titleWithEntity;
- }
- return labeledIntent.titleWithoutEntity;
- };
- }
- return null;
+ List<LabeledIntent> labeledIntents = templateIntentFactory.create(remoteActionTemplates);
+ if (labeledIntents.isEmpty()) {
+ return null;
}
+ // Given that we only support implicit intent here, we should expect there is just one
+ // intent for each action type.
+ LabeledIntent.TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(nativeSuggestion.getActionType());
+ return labeledIntents.get(0).resolve(context, titleChooser, null);
+ }
- /**
- * Returns a list of {@link ConversationAction}s that have 0 duplicates. Two actions are
- * duplicates if they may look the same to users. This function assumes every
- * ConversationActions with a non-null RemoteAction also have a non-null intent in the extras.
- */
- public static List<ConversationAction> removeActionsWithDuplicates(
- List<ConversationAction> conversationActions) {
- // Ideally, we should compare title and icon here, but comparing icon is expensive and thus
- // we use the component name of the target handler as the heuristic.
- Map<Pair<String, String>, Integer> counter = new ArrayMap<>();
- for (ConversationAction conversationAction : conversationActions) {
- Pair<String, String> representation = getRepresentation(conversationAction);
- if (representation == null) {
- continue;
- }
- Integer existingCount = counter.getOrDefault(representation, 0);
- counter.put(representation, existingCount + 1);
+ /** Returns a {@link LabeledIntent.TitleChooser} for conversation actions use case. */
+ @Nullable
+ public static LabeledIntent.TitleChooser createTitleChooser(String actionType) {
+ if (ConversationAction.TYPE_OPEN_URL.equals(actionType)) {
+ return (labeledIntent, resolveInfo) -> {
+ if (resolveInfo.handleAllWebDataURI) {
+ return labeledIntent.titleWithEntity;
}
- List<ConversationAction> result = new ArrayList<>();
- for (ConversationAction conversationAction : conversationActions) {
- Pair<String, String> representation = getRepresentation(conversationAction);
- if (representation == null || counter.getOrDefault(representation, 0) == 1) {
- result.add(conversationAction);
- }
+ if ("android".equals(resolveInfo.activityInfo.packageName)) {
+ return labeledIntent.titleWithEntity;
}
- return result;
+ return labeledIntent.titleWithoutEntity;
+ };
}
+ return null;
+ }
- @Nullable
- private static Pair<String, String> getRepresentation(ConversationAction conversationAction) {
- RemoteAction remoteAction = conversationAction.getAction();
- if (remoteAction == null) {
- return null;
- }
- Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
- ComponentName componentName = actionIntent.getComponent();
- // Action without a component name will be considered as from the same app.
- String packageName = componentName == null ? null : componentName.getPackageName();
- return new Pair<>(conversationAction.getAction().getTitle().toString(), packageName);
+ /**
+ * Returns a list of {@link ConversationAction}s that have 0 duplicates. Two actions are
+ * duplicates if they may look the same to users. This function assumes every ConversationActions
+ * with a non-null RemoteAction also have a non-null intent in the extras.
+ */
+ public static List<ConversationAction> removeActionsWithDuplicates(
+ List<ConversationAction> conversationActions) {
+ // Ideally, we should compare title and icon here, but comparing icon is expensive and thus
+ // we use the component name of the target handler as the heuristic.
+ Map<Pair<String, String>, Integer> counter = new ArrayMap<>();
+ for (ConversationAction conversationAction : conversationActions) {
+ Pair<String, String> representation = getRepresentation(conversationAction);
+ if (representation == null) {
+ continue;
+ }
+ Integer existingCount = counter.getOrDefault(representation, 0);
+ counter.put(representation, existingCount + 1);
}
-
- private static final class PersonEncoder {
- private final Map<Person, Integer> mMapping = new ArrayMap<>();
- private int mNextUserId = FIRST_NON_LOCAL_USER;
-
- private int encode(Person person) {
- if (ConversationActions.Message.PERSON_USER_SELF.equals(person)) {
- return USER_LOCAL;
- }
- Integer result = mMapping.get(person);
- if (result == null) {
- mMapping.put(person, mNextUserId);
- result = mNextUserId;
- mNextUserId++;
- }
- return result;
- }
+ List<ConversationAction> result = new ArrayList<>();
+ for (ConversationAction conversationAction : conversationActions) {
+ Pair<String, String> representation = getRepresentation(conversationAction);
+ if (representation == null || counter.getOrDefault(representation, 0) == 1) {
+ result.add(conversationAction);
+ }
}
+ return result;
+ }
- private static int hashMessage(ConversationActions.Message message) {
- return Objects.hash(message.getAuthor(), message.getText(), message.getReferenceTime());
+ @Nullable
+ private static Pair<String, String> getRepresentation(ConversationAction conversationAction) {
+ RemoteAction remoteAction = conversationAction.getAction();
+ if (remoteAction == null) {
+ return null;
}
+ Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
+ ComponentName componentName = actionIntent.getComponent();
+ // Action without a component name will be considered as from the same app.
+ String packageName = componentName == null ? null : componentName.getPackageName();
+ return new Pair<>(conversationAction.getAction().getTitle().toString(), packageName);
+ }
+
+ private static final class PersonEncoder {
+ private final Map<Person, Integer> personToUserIdMap = new ArrayMap<>();
+ private int nextUserId = FIRST_NON_LOCAL_USER;
+
+ private int encode(Person person) {
+ if (ConversationActions.Message.PERSON_USER_SELF.equals(person)) {
+ return USER_LOCAL;
+ }
+ Integer result = personToUserIdMap.get(person);
+ if (result == null) {
+ personToUserIdMap.put(person, nextUserId);
+ result = nextUserId;
+ nextUserId++;
+ }
+ return result;
+ }
+ }
+
+ private static int hashMessage(ConversationActions.Message message) {
+ return Objects.hash(message.getAuthor(), message.getText(), message.getReferenceTime());
+ }
}
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index e842247..32cba2f 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
- * limitations under the License
+ * limitations under the License.
*/
package com.android.textclassifier;
@@ -26,160 +26,152 @@
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
-
import com.android.textclassifier.utils.IndentingPrintWriter;
-
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
-
import java.io.FileDescriptor;
import java.io.PrintWriter;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
+/** An implementation of a TextClassifierService. */
public final class DefaultTextClassifierService extends TextClassifierService {
- private static final String TAG = "default_tcs";
+ private static final String TAG = "default_tcs";
- // TODO: Figure out do we need more concurrency.
- private final ListeningExecutorService mNormPriorityExecutor =
- MoreExecutors.listeningDecorator(
- Executors.newFixedThreadPool(
- /* nThreads= */ 2,
- new ThreadFactoryBuilder()
- .setNameFormat("tcs-norm-prio-executor")
- .setPriority(Thread.NORM_PRIORITY)
- .build()));
+ // TODO: Figure out do we need more concurrency.
+ private final ListeningExecutorService normPriorityExecutor =
+ MoreExecutors.listeningDecorator(
+ Executors.newFixedThreadPool(
+ /* nThreads= */ 2,
+ new ThreadFactoryBuilder()
+ .setNameFormat("tcs-norm-prio-executor")
+ .setPriority(Thread.NORM_PRIORITY)
+ .build()));
- private final ListeningExecutorService mLowPriorityExecutor =
- MoreExecutors.listeningDecorator(
- Executors.newSingleThreadExecutor(
- new ThreadFactoryBuilder()
- .setNameFormat("tcs-low-prio-executor")
- .setPriority(Thread.NORM_PRIORITY - 1)
- .build()));
+ private final ListeningExecutorService lowPriorityExecutor =
+ MoreExecutors.listeningDecorator(
+ Executors.newSingleThreadExecutor(
+ new ThreadFactoryBuilder()
+ .setNameFormat("tcs-low-prio-executor")
+ .setPriority(Thread.NORM_PRIORITY - 1)
+ .build()));
- private TextClassifierImpl mTextClassifier;
+ private TextClassifierImpl textClassifier;
- @Override
- public void onCreate() {
- super.onCreate();
- mTextClassifier = new TextClassifierImpl(this, new TextClassificationConstants());
- }
+ @Override
+ public void onCreate() {
+ super.onCreate();
+ textClassifier = new TextClassifierImpl(this, new TextClassificationConstants());
+ }
- @Override
- public void onSuggestSelection(
- TextClassificationSessionId sessionId,
- TextSelection.Request request,
- CancellationSignal cancellationSignal,
- Callback<TextSelection> callback) {
- handleRequestAsync(
- () -> mTextClassifier.suggestSelection(request), callback, cancellationSignal);
- }
+ @Override
+ public void onSuggestSelection(
+ TextClassificationSessionId sessionId,
+ TextSelection.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextSelection> callback) {
+ handleRequestAsync(
+ () -> textClassifier.suggestSelection(request), callback, cancellationSignal);
+ }
- @Override
- public void onClassifyText(
- TextClassificationSessionId sessionId,
- TextClassification.Request request,
- CancellationSignal cancellationSignal,
- Callback<TextClassification> callback) {
- handleRequestAsync(
- () -> mTextClassifier.classifyText(request), callback, cancellationSignal);
- }
+ @Override
+ public void onClassifyText(
+ TextClassificationSessionId sessionId,
+ TextClassification.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextClassification> callback) {
+ handleRequestAsync(() -> textClassifier.classifyText(request), callback, cancellationSignal);
+ }
- @Override
- public void onGenerateLinks(
- TextClassificationSessionId sessionId,
- TextLinks.Request request,
- CancellationSignal cancellationSignal,
- Callback<TextLinks> callback) {
- handleRequestAsync(
- () -> mTextClassifier.generateLinks(request), callback, cancellationSignal);
- }
+ @Override
+ public void onGenerateLinks(
+ TextClassificationSessionId sessionId,
+ TextLinks.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextLinks> callback) {
+ handleRequestAsync(() -> textClassifier.generateLinks(request), callback, cancellationSignal);
+ }
- @Override
- public void onSuggestConversationActions(
- TextClassificationSessionId sessionId,
- ConversationActions.Request request,
- CancellationSignal cancellationSignal,
- Callback<ConversationActions> callback) {
- handleRequestAsync(
- () -> mTextClassifier.suggestConversationActions(request),
- callback,
- cancellationSignal);
- }
+ @Override
+ public void onSuggestConversationActions(
+ TextClassificationSessionId sessionId,
+ ConversationActions.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<ConversationActions> callback) {
+ handleRequestAsync(
+ () -> textClassifier.suggestConversationActions(request), callback, cancellationSignal);
+ }
- @Override
- public void onDetectLanguage(
- TextClassificationSessionId sessionId,
- TextLanguage.Request request,
- CancellationSignal cancellationSignal,
- Callback<TextLanguage> callback) {
- handleRequestAsync(
- () -> mTextClassifier.detectLanguage(request), callback, cancellationSignal);
- }
+ @Override
+ public void onDetectLanguage(
+ TextClassificationSessionId sessionId,
+ TextLanguage.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextLanguage> callback) {
+ handleRequestAsync(() -> textClassifier.detectLanguage(request), callback, cancellationSignal);
+ }
- @Override
- public void onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event) {
- handleEvent(() -> mTextClassifier.onSelectionEvent(event));
- }
+ @Override
+ public void onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event) {
+ handleEvent(() -> textClassifier.onSelectionEvent(event));
+ }
- @Override
- public void onTextClassifierEvent(
- TextClassificationSessionId sessionId, TextClassifierEvent event) {
- handleEvent(() -> mTextClassifier.onTextClassifierEvent(sessionId, event));
- }
+ @Override
+ public void onTextClassifierEvent(
+ TextClassificationSessionId sessionId, TextClassifierEvent event) {
+ handleEvent(() -> textClassifier.onTextClassifierEvent(sessionId, event));
+ }
- @Override
- protected void dump(FileDescriptor fd, PrintWriter writer, String[] args) {
- IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
- mTextClassifier.dump(indentingPrintWriter);
- indentingPrintWriter.flush();
- }
+ @Override
+ protected void dump(FileDescriptor fd, PrintWriter writer, String[] args) {
+ IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
+ textClassifier.dump(indentingPrintWriter);
+ indentingPrintWriter.flush();
+ }
- private <T> void handleRequestAsync(
- Callable<T> callable, Callback<T> callback, CancellationSignal cancellationSignal) {
- ListenableFuture<T> result = mNormPriorityExecutor.submit(callable);
- Futures.addCallback(
- result,
- new FutureCallback<T>() {
- @Override
- public void onSuccess(T result) {
- callback.onSuccess(result);
- }
+ private <T> void handleRequestAsync(
+ Callable<T> callable, Callback<T> callback, CancellationSignal cancellationSignal) {
+ ListenableFuture<T> result = normPriorityExecutor.submit(callable);
+ Futures.addCallback(
+ result,
+ new FutureCallback<T>() {
+ @Override
+ public void onSuccess(T result) {
+ callback.onSuccess(result);
+ }
- @Override
- public void onFailure(Throwable t) {
- TcLog.e(TAG, "onFailure: ", t);
- callback.onFailure(t.getMessage());
- }
- },
- MoreExecutors.directExecutor());
- cancellationSignal.setOnCancelListener(
- () -> result.cancel(/* mayInterruptIfRunning= */ true));
- }
+ @Override
+ public void onFailure(Throwable t) {
+ TcLog.e(TAG, "onFailure: ", t);
+ callback.onFailure(t.getMessage());
+ }
+ },
+ MoreExecutors.directExecutor());
+ cancellationSignal.setOnCancelListener(() -> result.cancel(/* mayInterruptIfRunning= */ true));
+ }
- private void handleEvent(Runnable runnable) {
- ListenableFuture<Void> result =
- mLowPriorityExecutor.submit(
- () -> {
- runnable.run();
- return null;
- });
- Futures.addCallback(
- result,
- new FutureCallback<Void>() {
- @Override
- public void onSuccess(Void result) {}
+ private void handleEvent(Runnable runnable) {
+ ListenableFuture<Void> result =
+ lowPriorityExecutor.submit(
+ () -> {
+ runnable.run();
+ return null;
+ });
+ Futures.addCallback(
+ result,
+ new FutureCallback<Void>() {
+ @Override
+ public void onSuccess(Void result) {}
- @Override
- public void onFailure(Throwable t) {
- TcLog.e(TAG, "onFailure: ", t);
- }
- },
- MoreExecutors.directExecutor());
- }
+ @Override
+ public void onFailure(Throwable t) {
+ TcLog.e(TAG, "onFailure: ", t);
+ }
+ },
+ MoreExecutors.directExecutor());
+ }
}
diff --git a/java/src/com/android/textclassifier/Entity.java b/java/src/com/android/textclassifier/Entity.java
index 8a29923..6410a3e 100644
--- a/java/src/com/android/textclassifier/Entity.java
+++ b/java/src/com/android/textclassifier/Entity.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,62 +17,65 @@
package com.android.textclassifier;
import androidx.annotation.FloatRange;
-
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
/** A representation of an identified entity with the confidence score */
public final class Entity implements Comparable<Entity> {
- private final String mEntityType;
- private float mScore;
+ private final String entityType;
+ private final float score;
- public Entity(String entityType, float score) {
- mEntityType = Preconditions.checkNotNull(entityType);
- mScore = score;
- }
+ public Entity(String entityType, float score) {
+ this.entityType = Preconditions.checkNotNull(entityType);
+ this.score = score;
+ }
- public String getEntityType() {
- return mEntityType;
- }
+ public String getEntityType() {
+ return entityType;
+ }
- /**
- * Returns the confidence score of the entity, which ranged from 0.0 (low confidence) to 1.0
- * (high confidence).
- */
- @FloatRange(from = 0.0, to = 1.0)
- public Float getScore() {
- return mScore;
- }
+ /**
+ * Returns the confidence score of the entity, which ranged from 0.0 (low confidence) to 1.0 (high
+ * confidence).
+ */
+ @FloatRange(from = 0.0, to = 1.0)
+ public Float getScore() {
+ return score;
+ }
- @Override
- public int hashCode() {
- return Objects.hashCode(mEntityType, mScore);
- }
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(entityType, score);
+ }
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- Entity entity = (Entity) o;
- return Float.compare(entity.mScore, mScore) == 0
- && java.util.Objects.equals(mEntityType, entity.mEntityType);
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
}
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ Entity entity = (Entity) o;
+ return Float.compare(entity.score, score) == 0
+ && java.util.Objects.equals(entityType, entity.entityType);
+ }
- @Override
- public String toString() {
- return "Entity{" + mEntityType + ": " + mScore + "}";
- }
+ @Override
+ public String toString() {
+ return "Entity{" + entityType + ": " + score + "}";
+ }
- @Override
- public int compareTo(Entity entity) {
- // This method is implemented for sorting Entity. Sort the entities by the confidence score
- // in descending order firstly. If the scores are the same, then sort them by the entity
- // type in ascending order.
- int result = Float.compare(entity.getScore(), mScore);
- if (result == 0) {
- return mEntityType.compareTo(entity.getEntityType());
- }
- return result;
+ @Override
+ public int compareTo(Entity entity) {
+ // This method is implemented for sorting Entity. Sort the entities by the confidence score
+ // in descending order firstly. If the scores are the same, then sort them by the entity
+ // type in ascending order.
+ int result = Float.compare(entity.getScore(), score);
+ if (result == 0) {
+ return entityType.compareTo(entity.getEntityType());
}
+ return result;
+ }
}
diff --git a/java/src/com/android/textclassifier/EntityConfidence.java b/java/src/com/android/textclassifier/EntityConfidence.java
index d3d26e4..ef8ff05 100644
--- a/java/src/com/android/textclassifier/EntityConfidence.java
+++ b/java/src/com/android/textclassifier/EntityConfidence.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,10 +17,8 @@
package com.android.textclassifier;
import androidx.annotation.FloatRange;
-import androidx.annotation.NonNull;
import androidx.collection.ArrayMap;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -29,65 +27,64 @@
/** Helper object for setting and getting entity scores for classified text. */
final class EntityConfidence {
- static final EntityConfidence EMPTY = new EntityConfidence(Collections.emptyMap());
+ static final EntityConfidence EMPTY = new EntityConfidence(Collections.emptyMap());
- private final ArrayMap<String, Float> mEntityConfidence = new ArrayMap<>();
- private final ArrayList<String> mSortedEntities = new ArrayList<>();
+ private final ArrayMap<String, Float> entityConfidence = new ArrayMap<>();
+ private final ArrayList<String> sortedEntities = new ArrayList<>();
- /**
- * Constructs an EntityConfidence from a map of entity to confidence.
- *
- * <p>Map entries that have 0 confidence are removed, and values greater than 1 are clamped to
- * 1.
- *
- * @param source a map from entity to a confidence value in the range 0 (low confidence) to 1
- * (high confidence).
- */
- EntityConfidence(@NonNull Map<String, Float> source) {
- Preconditions.checkNotNull(source);
+ /**
+ * Constructs an EntityConfidence from a map of entity to confidence.
+ *
+ * <p>Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1.
+ *
+ * @param source a map from entity to a confidence value in the range 0 (low confidence) to 1
+ * (high confidence).
+ */
+ EntityConfidence(Map<String, Float> source) {
+ Preconditions.checkNotNull(source);
- // Prune non-existent entities and clamp to 1.
- mEntityConfidence.ensureCapacity(source.size());
- for (Map.Entry<String, Float> it : source.entrySet()) {
- if (it.getValue() <= 0) continue;
- mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
- }
- resetSortedEntitiesFromMap();
+ // Prune non-existent entities and clamp to 1.
+ entityConfidence.ensureCapacity(source.size());
+ for (Map.Entry<String, Float> it : source.entrySet()) {
+ if (it.getValue() <= 0) {
+ continue;
+ }
+ entityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
}
+ resetSortedEntitiesFromMap();
+ }
- /**
- * Returns an immutable list of entities found in the classified text ordered from high
- * confidence to low confidence.
- */
- @NonNull
- public List<String> getEntities() {
- return Collections.unmodifiableList(mSortedEntities);
- }
+ /**
+ * Returns an immutable list of entities found in the classified text ordered from high confidence
+ * to low confidence.
+ */
+ public List<String> getEntities() {
+ return Collections.unmodifiableList(sortedEntities);
+ }
- /**
- * Returns the confidence score for the specified entity. The value ranges from 0 (low
- * confidence) to 1 (high confidence). 0 indicates that the entity was not found for the
- * classified text.
- */
- @FloatRange(from = 0.0, to = 1.0)
- public float getConfidenceScore(String entity) {
- return mEntityConfidence.getOrDefault(entity, 0f);
- }
+ /**
+ * Returns the confidence score for the specified entity. The value ranges from 0 (low confidence)
+ * to 1 (high confidence). 0 indicates that the entity was not found for the classified text.
+ */
+ @FloatRange(from = 0.0, to = 1.0)
+ public float getConfidenceScore(String entity) {
+ return entityConfidence.getOrDefault(entity, 0f);
+ }
- @Override
- public String toString() {
- return mEntityConfidence.toString();
- }
+ @Override
+ public String toString() {
+ return entityConfidence.toString();
+ }
- private void resetSortedEntitiesFromMap() {
- mSortedEntities.clear();
- mSortedEntities.ensureCapacity(mEntityConfidence.size());
- mSortedEntities.addAll(mEntityConfidence.keySet());
- mSortedEntities.sort(
- (e1, e2) -> {
- float score1 = mEntityConfidence.get(e1);
- float score2 = mEntityConfidence.get(e2);
- return Float.compare(score2, score1);
- });
- }
+ private void resetSortedEntitiesFromMap() {
+ sortedEntities.clear();
+ sortedEntities.ensureCapacity(entityConfidence.size());
+ sortedEntities.addAll(entityConfidence.keySet());
+ sortedEntities.sort(
+ (e1, e2) -> {
+ float score1 = entityConfidence.get(e1);
+ float score2 = entityConfidence.get(e2);
+ return Float.compare(score2, score1);
+ });
+ }
}
diff --git a/java/src/com/android/textclassifier/ExtrasUtils.java b/java/src/com/android/textclassifier/ExtrasUtils.java
index 8d18e19..2039d76 100644
--- a/java/src/com/android/textclassifier/ExtrasUtils.java
+++ b/java/src/com/android/textclassifier/ExtrasUtils.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,295 +23,287 @@
import android.view.textclassifier.TextClassification;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLinks;
-
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-
import com.google.android.textclassifier.AnnotatorModel;
-
+import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
+import javax.annotation.Nullable;
/** Utility class for inserting and retrieving data in TextClassifier request/response extras. */
// TODO: Make this a TestApi for CTS testing.
public final class ExtrasUtils {
- // Keys for response objects.
- private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
- private static final String ENTITIES_EXTRAS = "entities-extras";
- private static final String ACTION_INTENT = "action-intent";
- private static final String ACTIONS_INTENTS = "actions-intents";
- private static final String FOREIGN_LANGUAGE = "foreign-language";
- private static final String ENTITY_TYPE = "entity-type";
- private static final String SCORE = "score";
- private static final String MODEL_VERSION = "model-version";
- private static final String MODEL_NAME = "model-name";
- private static final String TEXT_LANGUAGES = "text-languages";
- private static final String ENTITIES = "entities";
+ // Keys for response objects.
+ private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
+ private static final String ENTITIES_EXTRAS = "entities-extras";
+ private static final String ACTION_INTENT = "action-intent";
+ private static final String ACTIONS_INTENTS = "actions-intents";
+ private static final String FOREIGN_LANGUAGE = "foreign-language";
+ private static final String ENTITY_TYPE = "entity-type";
+ private static final String SCORE = "score";
+ private static final String MODEL_VERSION = "model-version";
+ private static final String MODEL_NAME = "model-name";
+ private static final String TEXT_LANGUAGES = "text-languages";
+ private static final String ENTITIES = "entities";
- // Keys for request objects.
- private static final String IS_SERIALIZED_ENTITY_DATA_ENABLED =
- "is-serialized-entity-data-enabled";
+ // Keys for request objects.
+ private static final String IS_SERIALIZED_ENTITY_DATA_ENABLED =
+ "is-serialized-entity-data-enabled";
- private ExtrasUtils() {}
+ private ExtrasUtils() {}
- /** Bundles and returns foreign language detection information for TextClassifier responses. */
- static Bundle createForeignLanguageExtra(String language, float score, int modelVersion) {
- final Bundle bundle = new Bundle();
- bundle.putString(ENTITY_TYPE, language);
- bundle.putFloat(SCORE, score);
- bundle.putInt(MODEL_VERSION, modelVersion);
- bundle.putString(MODEL_NAME, "langId_v" + modelVersion);
- return bundle;
+ /** Bundles and returns foreign language detection information for TextClassifier responses. */
+ static Bundle createForeignLanguageExtra(String language, float score, int modelVersion) {
+ final Bundle bundle = new Bundle();
+ bundle.putString(ENTITY_TYPE, language);
+ bundle.putFloat(SCORE, score);
+ bundle.putInt(MODEL_VERSION, modelVersion);
+ bundle.putString(MODEL_NAME, "langId_v" + modelVersion);
+ return bundle;
+ }
+
+ /**
+ * Stores {@code extra} as foreign language information in TextClassifier response object's extras
+ * {@code container}.
+ *
+ * @see #getForeignLanguageExtra(TextClassification)
+ */
+ static void putForeignLanguageExtra(Bundle container, Bundle extra) {
+ container.putParcelable(FOREIGN_LANGUAGE, extra);
+ }
+
+ /**
+ * Returns foreign language detection information contained in the TextClassification object.
+ * responses.
+ *
+ * @see #putForeignLanguageExtra(Bundle, Bundle)
+ */
+ @Nullable
+ @VisibleForTesting
+ public static Bundle getForeignLanguageExtra(@Nullable TextClassification classification) {
+ if (classification == null) {
+ return null;
}
+ return classification.getExtras().getBundle(FOREIGN_LANGUAGE);
+ }
- /**
- * Stores {@code extra} as foreign language information in TextClassifier response object's
- * extras {@code container}.
- *
- * @see #getForeignLanguageExtra(TextClassification)
- */
- static void putForeignLanguageExtra(Bundle container, Bundle extra) {
- container.putParcelable(FOREIGN_LANGUAGE, extra);
+ /** @see #getTopLanguage(Intent) */
+ static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) {
+ final int maxSize = Math.min(3, languageScores.getEntities().size());
+ final String[] languages =
+ languageScores.getEntities().subList(0, maxSize).toArray(new String[0]);
+ final float[] scores = new float[languages.length];
+ for (int i = 0; i < languages.length; i++) {
+ scores[i] = languageScores.getConfidenceScore(languages[i]);
}
+ container.putStringArray(ENTITY_TYPE, languages);
+ container.putFloatArray(SCORE, scores);
+ }
- /**
- * Returns foreign language detection information contained in the TextClassification object.
- * responses.
- *
- * @see #putForeignLanguageExtra(Bundle, Bundle)
- */
- @Nullable
- @VisibleForTesting
- public static Bundle getForeignLanguageExtra(@Nullable TextClassification classification) {
- if (classification == null) {
- return null;
+ /** See {@link #putTopLanguageScores(Bundle, EntityConfidence)}. */
+ @Nullable
+ @VisibleForTesting
+ public static ULocale getTopLanguage(@Nullable Intent intent) {
+ if (intent == null) {
+ return null;
+ }
+ final Bundle tcBundle = intent.getBundleExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER);
+ if (tcBundle == null) {
+ return null;
+ }
+ final Bundle textLanguagesExtra = tcBundle.getBundle(TEXT_LANGUAGES);
+ if (textLanguagesExtra == null) {
+ return null;
+ }
+ final String[] languages = textLanguagesExtra.getStringArray(ENTITY_TYPE);
+ final float[] scores = textLanguagesExtra.getFloatArray(SCORE);
+ if (languages == null
+ || scores == null
+ || languages.length == 0
+ || languages.length != scores.length) {
+ return null;
+ }
+ int highestScoringIndex = 0;
+ for (int i = 1; i < languages.length; i++) {
+ if (scores[highestScoringIndex] < scores[i]) {
+ highestScoringIndex = i;
+ }
+ }
+ return ULocale.forLanguageTag(languages[highestScoringIndex]);
+ }
+
+ public static void putTextLanguagesExtra(Bundle container, Bundle extra) {
+ container.putBundle(TEXT_LANGUAGES, extra);
+ }
+
+ /**
+ * Stores {@code actionsIntents} information in TextClassifier response object's extras {@code
+ * container}.
+ */
+ static void putActionsIntents(Bundle container, ArrayList<Intent> actionsIntents) {
+ container.putParcelableArrayList(ACTIONS_INTENTS, actionsIntents);
+ }
+
+ /**
+ * Stores {@code actionIntent} information in TextClassifier response object's extras {@code
+ * container}.
+ */
+ public static void putActionIntent(Bundle container, @Nullable Intent actionIntent) {
+ container.putParcelable(ACTION_INTENT, actionIntent);
+ }
+
+ /** Returns {@code actionIntent} information contained in a TextClassifier response object. */
+ @Nullable
+ public static Intent getActionIntent(Bundle container) {
+ return container.getParcelable(ACTION_INTENT);
+ }
+
+ /**
+ * Stores serialized entity data information in TextClassifier response object's extras {@code
+ * container}.
+ */
+ public static void putSerializedEntityData(
+ Bundle container, @Nullable byte[] serializedEntityData) {
+ container.putByteArray(SERIALIZED_ENTITIES_DATA, serializedEntityData);
+ }
+
+ /** Returns serialized entity data information contained in a TextClassifier response object. */
+ @Nullable
+ public static byte[] getSerializedEntityData(Bundle container) {
+ return container.getByteArray(SERIALIZED_ENTITIES_DATA);
+ }
+
+ /**
+ * Stores {@code entities} information in TextClassifier response object's extras {@code
+ * container}.
+ *
+ * @see {@link #getCopyText(Bundle)}
+ */
+ public static void putEntitiesExtras(Bundle container, @Nullable Bundle entitiesExtras) {
+ container.putParcelable(ENTITIES_EXTRAS, entitiesExtras);
+ }
+
+ /**
+ * Returns {@code entities} information contained in a TextClassifier response object.
+ *
+ * @see {@link #putEntitiesExtras(Bundle, Bundle)}
+ */
+ @Nullable
+ public static String getCopyText(Bundle container) {
+ Bundle entitiesExtras = container.getParcelable(ENTITIES_EXTRAS);
+ if (entitiesExtras == null) {
+ return null;
+ }
+ return entitiesExtras.getString("text");
+ }
+
+ /** Returns {@code actionIntents} information contained in the TextClassification object. */
+ @Nullable
+ public static ArrayList<Intent> getActionsIntents(@Nullable TextClassification classification) {
+ if (classification == null) {
+ return null;
+ }
+ return classification.getExtras().getParcelableArrayList(ACTIONS_INTENTS);
+ }
+
+ /**
+ * Returns the first action found in the {@code classification} object with an intent action
+ * string, {@code intentAction}.
+ */
+ @Nullable
+ @VisibleForTesting
+ public static RemoteAction findAction(
+ @Nullable TextClassification classification, @Nullable String intentAction) {
+ if (classification == null || intentAction == null) {
+ return null;
+ }
+ final ArrayList<Intent> actionIntents = getActionsIntents(classification);
+ if (actionIntents != null) {
+ final int size = actionIntents.size();
+ for (int i = 0; i < size; i++) {
+ final Intent intent = actionIntents.get(i);
+ if (intent != null && intentAction.equals(intent.getAction())) {
+ return classification.getActions().get(i);
}
- return classification.getExtras().getBundle(FOREIGN_LANGUAGE);
+ }
}
+ return null;
+ }
- /** @see #getTopLanguage(Intent) */
- @VisibleForTesting
- static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) {
- final int maxSize = Math.min(3, languageScores.getEntities().size());
- final String[] languages =
- languageScores.getEntities().subList(0, maxSize).toArray(new String[0]);
- final float[] scores = new float[languages.length];
- for (int i = 0; i < languages.length; i++) {
- scores[i] = languageScores.getConfidenceScore(languages[i]);
- }
- container.putStringArray(ENTITY_TYPE, languages);
- container.putFloatArray(SCORE, scores);
- }
+ /** Returns the first "translate" action found in the {@code classification} object. */
+ @Nullable
+ @VisibleForTesting
+ public static RemoteAction findTranslateAction(@Nullable TextClassification classification) {
+ return findAction(classification, Intent.ACTION_TRANSLATE);
+ }
- /** @see #putTopLanguageScores(Bundle, EntityConfidence) */
- @Nullable
- @VisibleForTesting
- public static ULocale getTopLanguage(@Nullable Intent intent) {
- if (intent == null) {
- return null;
- }
- final Bundle tcBundle = intent.getBundleExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER);
- if (tcBundle == null) {
- return null;
- }
- final Bundle textLanguagesExtra = tcBundle.getBundle(TEXT_LANGUAGES);
- if (textLanguagesExtra == null) {
- return null;
- }
- final String[] languages = textLanguagesExtra.getStringArray(ENTITY_TYPE);
- final float[] scores = textLanguagesExtra.getFloatArray(SCORE);
- if (languages == null
- || scores == null
- || languages.length == 0
- || languages.length != scores.length) {
- return null;
- }
- int highestScoringIndex = 0;
- for (int i = 1; i < languages.length; i++) {
- if (scores[highestScoringIndex] < scores[i]) {
- highestScoringIndex = i;
- }
- }
- return ULocale.forLanguageTag(languages[highestScoringIndex]);
+ /** Returns the entity type contained in the {@code extra}. */
+ @Nullable
+ @VisibleForTesting
+ public static String getEntityType(@Nullable Bundle extra) {
+ if (extra == null) {
+ return null;
}
+ return extra.getString(ENTITY_TYPE);
+ }
- public static void putTextLanguagesExtra(Bundle container, Bundle extra) {
- container.putBundle(TEXT_LANGUAGES, extra);
+ /** Returns the score contained in the {@code extra}. */
+ @VisibleForTesting
+ public static float getScore(Bundle extra) {
+ final int defaultValue = -1;
+ if (extra == null) {
+ return defaultValue;
}
+ return extra.getFloat(SCORE, defaultValue);
+ }
- /**
- * Stores {@code actionIntents} information in TextClassifier response object's extras {@code
- * container}.
- */
- static void putActionsIntents(Bundle container, ArrayList<Intent> actionsIntents) {
- container.putParcelableArrayList(ACTIONS_INTENTS, actionsIntents);
+ /** Returns the model name contained in the {@code extra}. */
+ @Nullable
+ public static String getModelName(@Nullable Bundle extra) {
+ if (extra == null) {
+ return null;
}
+ return extra.getString(MODEL_NAME);
+ }
- /**
- * Stores {@code actionIntents} information in TextClassifier response object's extras {@code
- * container}.
- */
- public static void putActionIntent(Bundle container, @Nullable Intent actionIntent) {
- container.putParcelable(ACTION_INTENT, actionIntent);
+ /** Stores the entities from {@link AnnotatorModel.ClassificationResult} in {@code container}. */
+ public static void putEntities(
+ Bundle container, @Nullable AnnotatorModel.ClassificationResult[] classifications) {
+ if (classifications == null || classifications.length == 0) {
+ return;
}
+ ArrayList<Bundle> entitiesBundle = new ArrayList<>();
+ for (AnnotatorModel.ClassificationResult classification : classifications) {
+ if (classification == null) {
+ continue;
+ }
+ Bundle entityBundle = new Bundle();
+ entityBundle.putString(ENTITY_TYPE, classification.getCollection());
+ entityBundle.putByteArray(SERIALIZED_ENTITIES_DATA, classification.getSerializedEntityData());
+ entitiesBundle.add(entityBundle);
+ }
+ if (!entitiesBundle.isEmpty()) {
+ container.putParcelableArrayList(ENTITIES, entitiesBundle);
+ }
+ }
- /** Returns {@code actionIntent} information contained in a TextClassifier response object. */
- @Nullable
- public static Intent getActionIntent(Bundle container) {
- return container.getParcelable(ACTION_INTENT);
- }
+ /** Returns a list of entities contained in the {@code extra}. */
+ @Nullable
+ @VisibleForTesting
+ public static List<Bundle> getEntities(Bundle container) {
+ return container.getParcelableArrayList(ENTITIES);
+ }
- /**
- * Stores serialized entity data information in TextClassifier response object's extras {@code
- * container}.
- */
- public static void putSerializedEntityData(
- Bundle container, @Nullable byte[] serializedEntityData) {
- container.putByteArray(SERIALIZED_ENTITIES_DATA, serializedEntityData);
- }
+ /** Whether the annotator should populate serialized entity data into the result object. */
+ public static boolean isSerializedEntityDataEnabled(TextLinks.Request request) {
+ return request.getExtras().getBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED);
+ }
- /** Returns serialized entity data information contained in a TextClassifier response object. */
- @Nullable
- public static byte[] getSerializedEntityData(Bundle container) {
- return container.getByteArray(SERIALIZED_ENTITIES_DATA);
- }
-
- /**
- * Stores {@code entities} information in TextClassifier response object's extras {@code
- * container}.
- *
- * @see {@link #getCopyText(Bundle)}
- */
- public static void putEntitiesExtras(Bundle container, @Nullable Bundle entitiesExtras) {
- container.putParcelable(ENTITIES_EXTRAS, entitiesExtras);
- }
-
- /**
- * Returns {@code entities} information contained in a TextClassifier response object.
- *
- * @see {@link #putEntitiesExtras(Bundle, Bundle)}
- */
- @Nullable
- public static String getCopyText(Bundle container) {
- Bundle entitiesExtras = container.getParcelable(ENTITIES_EXTRAS);
- if (entitiesExtras == null) {
- return null;
- }
- return entitiesExtras.getString("text");
- }
-
- /** Returns {@code actionIntents} information contained in the TextClassification object. */
- @Nullable
- public static ArrayList<Intent> getActionsIntents(@Nullable TextClassification classification) {
- if (classification == null) {
- return null;
- }
- return classification.getExtras().getParcelableArrayList(ACTIONS_INTENTS);
- }
-
- /**
- * Returns the first action found in the {@code classification} object with an intent action
- * string, {@code intentAction}.
- */
- @Nullable
- @VisibleForTesting
- public static RemoteAction findAction(
- @Nullable TextClassification classification, @Nullable String intentAction) {
- if (classification == null || intentAction == null) {
- return null;
- }
- final ArrayList<Intent> actionIntents = getActionsIntents(classification);
- if (actionIntents != null) {
- final int size = actionIntents.size();
- for (int i = 0; i < size; i++) {
- final Intent intent = actionIntents.get(i);
- if (intent != null && intentAction.equals(intent.getAction())) {
- return classification.getActions().get(i);
- }
- }
- }
- return null;
- }
-
- /** Returns the first "translate" action found in the {@code classification} object. */
- @Nullable
- @VisibleForTesting
- public static RemoteAction findTranslateAction(@Nullable TextClassification classification) {
- return findAction(classification, Intent.ACTION_TRANSLATE);
- }
-
- /** Returns the entity type contained in the {@code extra}. */
- @Nullable
- @VisibleForTesting
- public static String getEntityType(@Nullable Bundle extra) {
- if (extra == null) {
- return null;
- }
- return extra.getString(ENTITY_TYPE);
- }
-
- /** Returns the score contained in the {@code extra}. */
- @VisibleForTesting
- public static float getScore(Bundle extra) {
- final int defaultValue = -1;
- if (extra == null) {
- return defaultValue;
- }
- return extra.getFloat(SCORE, defaultValue);
- }
-
- /** Returns the model name contained in the {@code extra}. */
- @Nullable
- public static String getModelName(@Nullable Bundle extra) {
- if (extra == null) {
- return null;
- }
- return extra.getString(MODEL_NAME);
- }
-
- /**
- * Stores the entities from {@link AnnotatorModel.ClassificationResult} in {@code container}.
- */
- public static void putEntities(
- Bundle container, @Nullable AnnotatorModel.ClassificationResult[] classifications) {
- if (classifications == null || classifications.length == 0) {
- return;
- }
- ArrayList<Bundle> entitiesBundle = new ArrayList<>();
- for (AnnotatorModel.ClassificationResult classification : classifications) {
- if (classification == null) {
- continue;
- }
- Bundle entityBundle = new Bundle();
- entityBundle.putString(ENTITY_TYPE, classification.getCollection());
- entityBundle.putByteArray(
- SERIALIZED_ENTITIES_DATA, classification.getSerializedEntityData());
- entitiesBundle.add(entityBundle);
- }
- if (!entitiesBundle.isEmpty()) {
- container.putParcelableArrayList(ENTITIES, entitiesBundle);
- }
- }
-
- /** Returns a list of entities contained in the {@code extra}. */
- @Nullable
- @VisibleForTesting
- public static List<Bundle> getEntities(Bundle container) {
- return container.getParcelableArrayList(ENTITIES);
- }
-
- /** Whether the annotator should populate serialized entity data into the result object. */
- public static boolean isSerializedEntityDataEnabled(TextLinks.Request request) {
- return request.getExtras().getBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED);
- }
-
- /**
- * To indicate whether the annotator should populate serialized entity data in the result
- * object.
- */
- @VisibleForTesting
- public static void putIsSerializedEntityDataEnabled(Bundle bundle, boolean isEnabled) {
- bundle.putBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED, isEnabled);
- }
+ /**
+ * To indicate whether the annotator should populate serialized entity data in the result object.
+ */
+ @VisibleForTesting
+ public static void putIsSerializedEntityDataEnabled(Bundle bundle, boolean isEnabled) {
+ bundle.putBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED, isEnabled);
+ }
}
diff --git a/java/src/com/android/textclassifier/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
index 69ae40e..1d8bf3e 100644
--- a/java/src/com/android/textclassifier/ModelFileManager.java
+++ b/java/src/com/android/textclassifier/ModelFileManager.java
@@ -13,16 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier;
import android.os.LocaleList;
import android.os.ParcelFileDescriptor;
import android.text.TextUtils;
-
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
+import com.google.common.base.Splitter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
@@ -36,263 +34,262 @@
import java.util.function.Supplier;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import javax.annotation.Nullable;
/** Manages model files that are listed by the model files supplier. */
-@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
-public final class ModelFileManager {
- private static final String TAG = "ModelFileManager";
+final class ModelFileManager {
+ private static final String TAG = "ModelFileManager";
- private final Object mLock = new Object();
- private final Supplier<List<ModelFile>> mModelFileSupplier;
+ private final Object lock = new Object();
+ private final Supplier<List<ModelFile>> modelFileSupplier;
- private List<ModelFile> mModelFiles;
+ private List<ModelFile> modelFiles;
- public ModelFileManager(Supplier<List<ModelFile>> modelFileSupplier) {
- mModelFileSupplier = Preconditions.checkNotNull(modelFileSupplier);
+ public ModelFileManager(Supplier<List<ModelFile>> modelFileSupplier) {
+ this.modelFileSupplier = Preconditions.checkNotNull(modelFileSupplier);
+ }
+
+ /**
+ * Returns an unmodifiable list of model files listed by the given model files supplier.
+ *
+ * <p>The result is cached.
+ */
+ public List<ModelFile> listModelFiles() {
+ synchronized (lock) {
+ if (modelFiles == null) {
+ modelFiles = Collections.unmodifiableList(modelFileSupplier.get());
+ }
+ return modelFiles;
+ }
+ }
+
+ /**
+ * Returns the best model file for the given localelist, {@code null} if nothing is found.
+ *
+ * @param localeList the required locales, use {@code null} if there is no preference.
+ */
+ public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
+ final String languages =
+ localeList == null || localeList.isEmpty()
+ ? LocaleList.getDefault().toLanguageTags()
+ : localeList.toLanguageTags();
+ final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
+
+ ModelFile bestModel = null;
+ for (ModelFile model : listModelFiles()) {
+ if (model.isAnyLanguageSupported(languageRangeList)) {
+ if (model.isPreferredTo(bestModel)) {
+ bestModel = model;
+ }
+ }
+ }
+ return bestModel;
+ }
+
+ /** Default implementation of the model file supplier. */
+ public static final class ModelFileSupplierImpl implements Supplier<List<ModelFile>> {
+ private final File updatedModelFile;
+ private final File factoryModelDir;
+ private final Pattern modelFilenamePattern;
+ private final Function<Integer, Integer> versionSupplier;
+ private final Function<Integer, String> supportedLocalesSupplier;
+
+ public ModelFileSupplierImpl(
+ File factoryModelDir,
+ String factoryModelFileNameRegex,
+ File updatedModelFile,
+ Function<Integer, Integer> versionSupplier,
+ Function<Integer, String> supportedLocalesSupplier) {
+ this.updatedModelFile = Preconditions.checkNotNull(updatedModelFile);
+ this.factoryModelDir = Preconditions.checkNotNull(factoryModelDir);
+ modelFilenamePattern = Pattern.compile(Preconditions.checkNotNull(factoryModelFileNameRegex));
+ this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
+ this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
}
- /**
- * Returns an unmodifiable list of model files listed by the given model files supplier.
- *
- * <p>The result is cached.
- */
- public List<ModelFile> listModelFiles() {
- synchronized (mLock) {
- if (mModelFiles == null) {
- mModelFiles = Collections.unmodifiableList(mModelFileSupplier.get());
- }
- return mModelFiles;
+ @Override
+ public List<ModelFile> get() {
+ final List<ModelFile> modelFiles = new ArrayList<>();
+ // The update model has the highest precedence.
+ if (updatedModelFile.exists()) {
+ final ModelFile updatedModel = createModelFile(updatedModelFile);
+ if (updatedModel != null) {
+ modelFiles.add(updatedModel);
}
+ }
+ // Factory models should never have overlapping locales, so the order doesn't matter.
+ if (factoryModelDir.exists() && factoryModelDir.isDirectory()) {
+ final File[] files = factoryModelDir.listFiles();
+ for (File file : files) {
+ final Matcher matcher = modelFilenamePattern.matcher(file.getName());
+ if (matcher.matches() && file.isFile()) {
+ final ModelFile model = createModelFile(file);
+ if (model != null) {
+ modelFiles.add(model);
+ }
+ }
+ }
+ }
+ return modelFiles;
}
- /**
- * Returns the best model file for the given localelist, {@code null} if nothing is found.
- *
- * @param localeList the required locales, use {@code null} if there is no preference.
- */
- public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
- final String languages =
- localeList == null || localeList.isEmpty()
- ? LocaleList.getDefault().toLanguageTags()
- : localeList.toLanguageTags();
- final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
-
- ModelFile bestModel = null;
- for (ModelFile model : listModelFiles()) {
- if (model.isAnyLanguageSupported(languageRangeList)) {
- if (model.isPreferredTo(bestModel)) {
- bestModel = model;
- }
- }
+ /** Returns null if the path did not point to a compatible model. */
+ @Nullable
+ private ModelFile createModelFile(File file) {
+ if (!file.exists()) {
+ return null;
+ }
+ ParcelFileDescriptor modelFd = null;
+ try {
+ modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ if (modelFd == null) {
+ return null;
}
- return bestModel;
+ final int modelFdInt = modelFd.getFd();
+ final int version = versionSupplier.apply(modelFdInt);
+ final String supportedLocalesStr = supportedLocalesSupplier.apply(modelFdInt);
+ if (supportedLocalesStr.isEmpty()) {
+ TcLog.d(TAG, "Ignoring " + file.getAbsolutePath());
+ return null;
+ }
+ final List<Locale> supportedLocales = new ArrayList<>();
+ for (String langTag : Splitter.on(',').split(supportedLocalesStr)) {
+ supportedLocales.add(Locale.forLanguageTag(langTag));
+ }
+ return new ModelFile(
+ file,
+ version,
+ supportedLocales,
+ supportedLocalesStr,
+ ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
+ } catch (FileNotFoundException e) {
+ TcLog.e(TAG, "Failed to find " + file.getAbsolutePath(), e);
+ return null;
+ } finally {
+ maybeCloseAndLogError(modelFd);
+ }
}
- /** Default implementation of the model file supplier. */
- public static final class ModelFileSupplierImpl implements Supplier<List<ModelFile>> {
- private final File mUpdatedModelFile;
- private final File mFactoryModelDir;
- private final Pattern mModelFilenamePattern;
- private final Function<Integer, Integer> mVersionSupplier;
- private final Function<Integer, String> mSupportedLocalesSupplier;
+ /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
+ private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
+ if (fd == null) {
+ return;
+ }
+ try {
+ fd.close();
+ } catch (IOException e) {
+ TcLog.e(TAG, "Error closing file.", e);
+ }
+ }
+ }
- public ModelFileSupplierImpl(
- File factoryModelDir,
- String factoryModelFileNameRegex,
- File updatedModelFile,
- Function<Integer, Integer> versionSupplier,
- Function<Integer, String> supportedLocalesSupplier) {
- mUpdatedModelFile = Preconditions.checkNotNull(updatedModelFile);
- mFactoryModelDir = Preconditions.checkNotNull(factoryModelDir);
- mModelFilenamePattern =
- Pattern.compile(Preconditions.checkNotNull(factoryModelFileNameRegex));
- mVersionSupplier = Preconditions.checkNotNull(versionSupplier);
- mSupportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
- }
+ /** Describes TextClassifier model files on disk. */
+ public static final class ModelFile {
+ public static final String LANGUAGE_INDEPENDENT = "*";
- @Override
- public List<ModelFile> get() {
- final List<ModelFile> modelFiles = new ArrayList<>();
- // The update model has the highest precedence.
- if (mUpdatedModelFile.exists()) {
- final ModelFile updatedModel = createModelFile(mUpdatedModelFile);
- if (updatedModel != null) {
- modelFiles.add(updatedModel);
- }
- }
- // Factory models should never have overlapping locales, so the order doesn't matter.
- if (mFactoryModelDir.exists() && mFactoryModelDir.isDirectory()) {
- final File[] files = mFactoryModelDir.listFiles();
- for (File file : files) {
- final Matcher matcher = mModelFilenamePattern.matcher(file.getName());
- if (matcher.matches() && file.isFile()) {
- final ModelFile model = createModelFile(file);
- if (model != null) {
- modelFiles.add(model);
- }
- }
- }
- }
- return modelFiles;
- }
+ private final File file;
+ private final int version;
+ private final List<Locale> supportedLocales;
+ private final String supportedLocalesStr;
+ private final boolean languageIndependent;
- /** Returns null if the path did not point to a compatible model. */
- @Nullable
- private ModelFile createModelFile(File file) {
- if (!file.exists()) {
- return null;
- }
- ParcelFileDescriptor modelFd = null;
- try {
- modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
- if (modelFd == null) {
- return null;
- }
- final int modelFdInt = modelFd.getFd();
- final int version = mVersionSupplier.apply(modelFdInt);
- final String supportedLocalesStr = mSupportedLocalesSupplier.apply(modelFdInt);
- if (supportedLocalesStr.isEmpty()) {
- TcLog.d(TAG, "Ignoring " + file.getAbsolutePath());
- return null;
- }
- final List<Locale> supportedLocales = new ArrayList<>();
- for (String langTag : supportedLocalesStr.split(",")) {
- supportedLocales.add(Locale.forLanguageTag(langTag));
- }
- return new ModelFile(
- file,
- version,
- supportedLocales,
- supportedLocalesStr,
- ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
- } catch (FileNotFoundException e) {
- TcLog.e(TAG, "Failed to find " + file.getAbsolutePath(), e);
- return null;
- } finally {
- maybeCloseAndLogError(modelFd);
- }
- }
-
- /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
- private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
- if (fd == null) {
- return;
- }
- try {
- fd.close();
- } catch (IOException e) {
- TcLog.e(TAG, "Error closing file.", e);
- }
- }
+ public ModelFile(
+ File file,
+ int version,
+ List<Locale> supportedLocales,
+ String supportedLocalesStr,
+ boolean languageIndependent) {
+ this.file = Preconditions.checkNotNull(file);
+ this.version = version;
+ this.supportedLocales = Preconditions.checkNotNull(supportedLocales);
+ this.supportedLocalesStr = Preconditions.checkNotNull(supportedLocalesStr);
+ this.languageIndependent = languageIndependent;
}
- /** Describes TextClassifier model files on disk. */
- public static final class ModelFile {
- public static final String LANGUAGE_INDEPENDENT = "*";
-
- private final File mFile;
- private final int mVersion;
- private final List<Locale> mSupportedLocales;
- private final String mSupportedLocalesStr;
- private final boolean mLanguageIndependent;
-
- public ModelFile(
- File file,
- int version,
- List<Locale> supportedLocales,
- String supportedLocalesStr,
- boolean languageIndependent) {
- mFile = Preconditions.checkNotNull(file);
- mVersion = version;
- mSupportedLocales = Preconditions.checkNotNull(supportedLocales);
- mSupportedLocalesStr = Preconditions.checkNotNull(supportedLocalesStr);
- mLanguageIndependent = languageIndependent;
- }
-
- /** Returns the absolute path to the model file. */
- public String getPath() {
- return mFile.getAbsolutePath();
- }
-
- /** Returns a name to use for id generation, effectively the name of the model file. */
- public String getName() {
- return mFile.getName();
- }
-
- /** Returns the version tag in the model's metadata. */
- public int getVersion() {
- return mVersion;
- }
-
- /** Returns whether the language supports any language in the given ranges. */
- public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
- Preconditions.checkNotNull(languageRanges);
- return mLanguageIndependent || Locale.lookup(languageRanges, mSupportedLocales) != null;
- }
-
- /** Returns an immutable lists of supported locales. */
- public List<Locale> getSupportedLocales() {
- return Collections.unmodifiableList(mSupportedLocales);
- }
-
- /** Returns the original supported locals string read from the model file. */
- public String getSupportedLocalesStr() {
- return mSupportedLocalesStr;
- }
-
- /** Returns if this model file is preferred to the given one. */
- public boolean isPreferredTo(@Nullable ModelFile model) {
- // A model is preferred to no model.
- if (model == null) {
- return true;
- }
-
- // A language-specific model is preferred to a language independent
- // model.
- if (!mLanguageIndependent && model.mLanguageIndependent) {
- return true;
- }
- if (mLanguageIndependent && !model.mLanguageIndependent) {
- return false;
- }
-
- // A higher-version model is preferred.
- if (mVersion > model.getVersion()) {
- return true;
- }
- return false;
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(getPath());
- }
-
- @Override
- public boolean equals(Object other) {
- if (this == other) {
- return true;
- }
- if (other instanceof ModelFile) {
- final ModelFile otherModel = (ModelFile) other;
- return TextUtils.equals(getPath(), otherModel.getPath());
- }
- return false;
- }
-
- @Override
- public String toString() {
- final StringJoiner localesJoiner = new StringJoiner(",");
- for (Locale locale : mSupportedLocales) {
- localesJoiner.add(locale.toLanguageTag());
- }
- return String.format(
- Locale.US,
- "ModelFile { path=%s name=%s version=%d locales=%s }",
- getPath(),
- getName(),
- mVersion,
- localesJoiner.toString());
- }
+ /** Returns the absolute path to the model file. */
+ public String getPath() {
+ return file.getAbsolutePath();
}
+
+ /** Returns a name to use for id generation, effectively the name of the model file. */
+ public String getName() {
+ return file.getName();
+ }
+
+ /** Returns the version tag in the model's metadata. */
+ public int getVersion() {
+ return version;
+ }
+
+ /** Returns whether the language supports any language in the given ranges. */
+ public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
+ Preconditions.checkNotNull(languageRanges);
+ return languageIndependent || Locale.lookup(languageRanges, supportedLocales) != null;
+ }
+
+ /** Returns an immutable lists of supported locales. */
+ public List<Locale> getSupportedLocales() {
+ return Collections.unmodifiableList(supportedLocales);
+ }
+
+ /** Returns the original supported locals string read from the model file. */
+ public String getSupportedLocalesStr() {
+ return supportedLocalesStr;
+ }
+
+ /** Returns if this model file is preferred to the given one. */
+ public boolean isPreferredTo(@Nullable ModelFile model) {
+ // A model is preferred to no model.
+ if (model == null) {
+ return true;
+ }
+
+ // A language-specific model is preferred to a language independent
+ // model.
+ if (!languageIndependent && model.languageIndependent) {
+ return true;
+ }
+ if (languageIndependent && !model.languageIndependent) {
+ return false;
+ }
+
+ // A higher-version model is preferred.
+ if (version > model.getVersion()) {
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(getPath());
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) {
+ return true;
+ }
+ if (other instanceof ModelFile) {
+ final ModelFile otherModel = (ModelFile) other;
+ return TextUtils.equals(getPath(), otherModel.getPath());
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ final StringJoiner localesJoiner = new StringJoiner(",");
+ for (Locale locale : supportedLocales) {
+ localesJoiner.add(locale.toLanguageTag());
+ }
+ return String.format(
+ Locale.US,
+ "ModelFile { path=%s name=%s version=%d locales=%s }",
+ getPath(),
+ getName(),
+ version,
+ localesJoiner);
+ }
+ }
}
diff --git a/java/src/com/android/textclassifier/StringUtils.java b/java/src/com/android/textclassifier/StringUtils.java
index d4968c8..a32d5e4 100644
--- a/java/src/com/android/textclassifier/StringUtils.java
+++ b/java/src/com/android/textclassifier/StringUtils.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -11,14 +11,13 @@
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
- * limitations under the License
+ * limitations under the License.
*/
package com.android.textclassifier;
import androidx.annotation.GuardedBy;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
import java.text.BreakIterator;
/**
@@ -31,52 +30,54 @@
* Intended to be used only for TextClassifier purposes.
*/
public final class StringUtils {
- private static final String TAG = "StringUtils";
+ private static final String TAG = "StringUtils";
- @GuardedBy("WORD_ITERATOR")
- private static final BreakIterator WORD_ITERATOR = BreakIterator.getWordInstance();
+ @GuardedBy("WORD_ITERATOR")
+ private static final BreakIterator WORD_ITERATOR = BreakIterator.getWordInstance();
- /**
- * Returns the substring of {@code text} that contains at least text from index {@code start}
- * <i>(inclusive)</i> to index {@code end} <i><(exclusive)/i> with the goal of returning text
- * that is at least {@code minimumLength}. If {@code text} is not long enough, this will return
- * {@code text}. This method returns text at word boundaries.
- *
- * @param text the source text
- * @param start the start index of text that must be included
- * @param end the end index of text that must be included
- * @param minimumLength minimum length of text to return if {@code text} is long enough
- */
- public static String getSubString(String text, int start, int end, int minimumLength) {
- Preconditions.checkArgument(start >= 0);
- Preconditions.checkArgument(end <= text.length());
- Preconditions.checkArgument(start <= end);
+ /**
+ * Returns the substring of {@code text} that contains at least text from index {@code start}
+ * <i>(inclusive)</i> to index {@code end} <i><(exclusive)/i> with the goal of returning text that
+ * is at least {@code minimumLength}. If {@code text} is not long enough, this will return {@code
+ * text}. This method returns text at word boundaries.
+ *
+ * @param text the source text
+ * @param start the start index of text that must be included
+ * @param end the end index of text that must be included
+ * @param minimumLength minimum length of text to return if {@code text} is long enough
+ */
+ public static String getSubString(String text, int start, int end, int minimumLength) {
+ Preconditions.checkArgument(start >= 0);
+ Preconditions.checkArgument(end <= text.length());
+ Preconditions.checkArgument(start <= end);
- if (text.length() < minimumLength) {
- return text;
- }
-
- final int length = end - start;
- if (length >= minimumLength) {
- return text.substring(start, end);
- }
-
- final int offset = (minimumLength - length) / 2;
- int iterStart = Math.max(0, Math.min(start - offset, text.length() - minimumLength));
- int iterEnd = Math.min(text.length(), iterStart + minimumLength);
-
- synchronized (WORD_ITERATOR) {
- WORD_ITERATOR.setText(text);
- iterStart =
- WORD_ITERATOR.isBoundary(iterStart)
- ? iterStart
- : Math.max(0, WORD_ITERATOR.preceding(iterStart));
- iterEnd =
- WORD_ITERATOR.isBoundary(iterEnd)
- ? iterEnd
- : Math.max(iterEnd, WORD_ITERATOR.following(iterEnd));
- WORD_ITERATOR.setText("");
- return text.substring(iterStart, iterEnd);
- }
+ if (text.length() < minimumLength) {
+ return text;
}
+
+ final int length = end - start;
+ if (length >= minimumLength) {
+ return text.substring(start, end);
+ }
+
+ final int offset = (minimumLength - length) / 2;
+ int iterStart = Math.max(0, Math.min(start - offset, text.length() - minimumLength));
+ int iterEnd = Math.min(text.length(), iterStart + minimumLength);
+
+ synchronized (WORD_ITERATOR) {
+ WORD_ITERATOR.setText(text);
+ iterStart =
+ WORD_ITERATOR.isBoundary(iterStart)
+ ? iterStart
+ : Math.max(0, WORD_ITERATOR.preceding(iterStart));
+ iterEnd =
+ WORD_ITERATOR.isBoundary(iterEnd)
+ ? iterEnd
+ : Math.max(iterEnd, WORD_ITERATOR.following(iterEnd));
+ WORD_ITERATOR.setText("");
+ return text.substring(iterStart, iterEnd);
+ }
+ }
+
+ private StringUtils() {}
}
diff --git a/java/src/com/android/textclassifier/TcLog.java b/java/src/com/android/textclassifier/TcLog.java
index 0c664ed..581d660 100644
--- a/java/src/com/android/textclassifier/TcLog.java
+++ b/java/src/com/android/textclassifier/TcLog.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,34 +25,34 @@
* @hide
*/
public final class TcLog {
- private static final boolean USE_TC_TAG = true;
- public static final String TAG = "androidtc";
+ private static final boolean USE_TC_TAG = true;
+ public static final String TAG = "androidtc";
- /** true: Enables full logging. false: Limits logging to debug level. */
- public static final boolean ENABLE_FULL_LOGGING =
- android.util.Log.isLoggable(TAG, android.util.Log.VERBOSE);
+ /** true: Enables full logging. false: Limits logging to debug level. */
+ public static final boolean ENABLE_FULL_LOGGING =
+ android.util.Log.isLoggable(TAG, android.util.Log.VERBOSE);
- private TcLog() {}
+ private TcLog() {}
- public static void v(String tag, String msg) {
- if (ENABLE_FULL_LOGGING) {
- android.util.Log.v(getTag(tag), msg);
- }
+ public static void v(String tag, String msg) {
+ if (ENABLE_FULL_LOGGING) {
+ android.util.Log.v(getTag(tag), msg);
}
+ }
- public static void d(String tag, String msg) {
- android.util.Log.d(getTag(tag), msg);
- }
+ public static void d(String tag, String msg) {
+ android.util.Log.d(getTag(tag), msg);
+ }
- public static void w(String tag, String msg) {
- android.util.Log.w(getTag(tag), msg);
- }
+ public static void w(String tag, String msg) {
+ android.util.Log.w(getTag(tag), msg);
+ }
- public static void e(String tag, String msg, Throwable tr) {
- android.util.Log.e(getTag(tag), msg, tr);
- }
+ public static void e(String tag, String msg, Throwable tr) {
+ android.util.Log.e(getTag(tag), msg, tr);
+ }
- private static String getTag(String customTag) {
- return USE_TC_TAG ? TAG : customTag;
- }
+ private static String getTag(String customTag) {
+ return USE_TC_TAG ? TAG : customTag;
+ }
}
diff --git a/java/src/com/android/textclassifier/TextClassificationConstants.java b/java/src/com/android/textclassifier/TextClassificationConstants.java
index f70c5ac..11b5179 100644
--- a/java/src/com/android/textclassifier/TextClassificationConstants.java
+++ b/java/src/com/android/textclassifier/TextClassificationConstants.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,14 +19,13 @@
import android.provider.DeviceConfig;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.TextClassifier;
-
-import androidx.annotation.Nullable;
-
import com.android.textclassifier.utils.IndentingPrintWriter;
-
+import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import javax.annotation.Nullable;
/**
* TextClassifier specific settings.
@@ -43,318 +42,311 @@
*/
// TODO: Rename to TextClassifierSettings.
public final class TextClassificationConstants {
- private static final String DELIMITER = ":";
+ private static final String DELIMITER = ":";
- /** Whether the user language profile feature is enabled. */
- private static final String USER_LANGUAGE_PROFILE_ENABLED = "user_language_profile_enabled";
- /** Max length of text that suggestSelection can accept. */
- private static final String SUGGEST_SELECTION_MAX_RANGE_LENGTH =
- "suggest_selection_max_range_length";
- /** Max length of text that classifyText can accept. */
- private static final String CLASSIFY_TEXT_MAX_RANGE_LENGTH = "classify_text_max_range_length";
- /** Max length of text that generateLinks can accept. */
- private static final String GENERATE_LINKS_MAX_TEXT_LENGTH = "generate_links_max_text_length";
- /** Sampling rate for generateLinks logging. */
- private static final String GENERATE_LINKS_LOG_SAMPLE_RATE = "generate_links_log_sample_rate";
- /**
- * Extra count that is added to some languages, e.g. system languages, when deducing the
- * frequent languages in {@link
- * com.android.textclassifier.ulp.LanguageProfileAnalyzer#getFrequentLanguages(int)}.
- */
- private static final String FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT =
- "frequent_languages_bootstrapping_count";
+ /** Whether the user language profile feature is enabled. */
+ private static final String USER_LANGUAGE_PROFILE_ENABLED = "user_language_profile_enabled";
+ /** Max length of text that suggestSelection can accept. */
+ private static final String SUGGEST_SELECTION_MAX_RANGE_LENGTH =
+ "suggest_selection_max_range_length";
+ /** Max length of text that classifyText can accept. */
+ private static final String CLASSIFY_TEXT_MAX_RANGE_LENGTH = "classify_text_max_range_length";
+ /** Max length of text that generateLinks can accept. */
+ private static final String GENERATE_LINKS_MAX_TEXT_LENGTH = "generate_links_max_text_length";
+ /** Sampling rate for generateLinks logging. */
+ private static final String GENERATE_LINKS_LOG_SAMPLE_RATE = "generate_links_log_sample_rate";
+ /**
+ * Extra count that is added to some languages, e.g. system languages, when deducing the frequent
+ * languages in {@link
+ * com.android.textclassifier.ulp.LanguageProfileAnalyzer#getFrequentLanguages(int)}.
+ */
+ private static final String FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT =
+ "frequent_languages_bootstrapping_count";
- /**
- * Default count for the language in the system settings while calculating {@code
- * LanguageProfileAnalyzer.getRecognizedLanguages()}
- */
- private static final String LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT =
- "language_proficiency_bootstrapping_count";
+ /**
+ * Default count for the language in the system settings while calculating {@code
+ * LanguageProfileAnalyzer.getRecognizedLanguages()}
+ */
+ private static final String LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT =
+ "language_proficiency_bootstrapping_count";
- /**
- * A colon(:) separated string that specifies the default entities types for generateLinks when
- * hint is not given.
- */
- private static final String ENTITY_LIST_DEFAULT = "entity_list_default";
- /**
- * A colon(:) separated string that specifies the default entities types for generateLinks when
- * the text is in a not editable UI widget.
- */
- private static final String ENTITY_LIST_NOT_EDITABLE = "entity_list_not_editable";
- /**
- * A colon(:) separated string that specifies the default entities types for generateLinks when
- * the text is in an editable UI widget.
- */
- private static final String ENTITY_LIST_EDITABLE = "entity_list_editable";
- /**
- * A colon(:) separated string that specifies the default action types for
- * suggestConversationActions when the suggestions are used in an app.
- */
- private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT =
- "in_app_conversation_action_types_default";
- /**
- * A colon(:) separated string that specifies the default action types for
- * suggestConversationActions when the suggestions are used in a notification.
- */
- private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
- "notification_conversation_action_types_default";
- /** Threshold to accept a suggested language from LangID model. */
- private static final String LANG_ID_THRESHOLD_OVERRIDE = "lang_id_threshold_override";
- /** Whether to enable {@link com.android.textclassifier.intent.TemplateIntentFactory}. */
- private static final String TEMPLATE_INTENT_FACTORY_ENABLED = "template_intent_factory_enabled";
- /** Whether to enable "translate" action in classifyText. */
- private static final String TRANSLATE_IN_CLASSIFICATION_ENABLED =
- "translate_in_classification_enabled";
- /**
- * Whether to detect the languages of the text in request by using langId for the native model.
- */
- private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
- "detect_languages_from_text_enabled";
- /**
- * A colon(:) separated string that specifies the configuration to use when including
- * surrounding context text in language detection queries.
- *
- * <p>Format= minimumTextSize<int>:penalizeRatio<float>:textScoreRatio<float>
- *
- * <p>e.g. 20:1.0:0.4
- *
- * <p>Accept all text lengths with minimumTextSize=0
- *
- * <p>Reject all text less than minimumTextSize with penalizeRatio=0
- *
- * @see {@code TextClassifierImpl#detectLanguages(String, int, int)} for reference.
- */
- private static final String LANG_ID_CONTEXT_SETTINGS = "lang_id_context_settings";
- /** Default threshold to translate the language of the context the user selects */
- private static final String TRANSLATE_ACTION_THRESHOLD = "translate_action_threshold";
+ /**
+ * A colon(:) separated string that specifies the default entities types for generateLinks when
+ * hint is not given.
+ */
+ private static final String ENTITY_LIST_DEFAULT = "entity_list_default";
+ /**
+ * A colon(:) separated string that specifies the default entities types for generateLinks when
+ * the text is in a not editable UI widget.
+ */
+ private static final String ENTITY_LIST_NOT_EDITABLE = "entity_list_not_editable";
+ /**
+ * A colon(:) separated string that specifies the default entities types for generateLinks when
+ * the text is in an editable UI widget.
+ */
+ private static final String ENTITY_LIST_EDITABLE = "entity_list_editable";
+ /**
+ * A colon(:) separated string that specifies the default action types for
+ * suggestConversationActions when the suggestions are used in an app.
+ */
+ private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT =
+ "in_app_conversation_action_types_default";
+ /**
+ * A colon(:) separated string that specifies the default action types for
+ * suggestConversationActions when the suggestions are used in a notification.
+ */
+ private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
+ "notification_conversation_action_types_default";
+ /** Threshold to accept a suggested language from LangID model. */
+ private static final String LANG_ID_THRESHOLD_OVERRIDE = "lang_id_threshold_override";
+ /** Whether to enable {@link com.android.textclassifier.intent.TemplateIntentFactory}. */
+ private static final String TEMPLATE_INTENT_FACTORY_ENABLED = "template_intent_factory_enabled";
+ /** Whether to enable "translate" action in classifyText. */
+ private static final String TRANSLATE_IN_CLASSIFICATION_ENABLED =
+ "translate_in_classification_enabled";
+ /**
+ * Whether to detect the languages of the text in request by using langId for the native model.
+ */
+ private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
+ "detect_languages_from_text_enabled";
+ /**
+ * A colon(:) separated string that specifies the configuration to use when including surrounding
+ * context text in language detection queries.
+ *
+ * <p>Format= minimumTextSize<int>:penalizeRatio<float>:textScoreRatio<float>
+ *
+ * <p>e.g. 20:1.0:0.4
+ *
+ * <p>Accept all text lengths with minimumTextSize=0
+ *
+ * <p>Reject all text less than minimumTextSize with penalizeRatio=0
+ *
+ * @see {@code TextClassifierImpl#detectLanguages(String, int, int)} for reference.
+ */
+ private static final String LANG_ID_CONTEXT_SETTINGS = "lang_id_context_settings";
+ /** Default threshold to translate the language of the context the user selects */
+ private static final String TRANSLATE_ACTION_THRESHOLD = "translate_action_threshold";
- // Sync this with ConversationAction.TYPE_ADD_CONTACT;
- public static final String TYPE_ADD_CONTACT = "add_contact";
- // Sync this with ConversationAction.COPY;
- public static final String TYPE_COPY = "copy";
+ // Sync this with ConversationAction.TYPE_ADD_CONTACT;
+ public static final String TYPE_ADD_CONTACT = "add_contact";
+ // Sync this with ConversationAction.COPY;
+ public static final String TYPE_COPY = "copy";
- private static final int SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
- private static final int CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
- private static final int GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT = 100 * 1000;
- private static final int GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT = 100;
- private static final int FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT_DEFAULT = 100;
- private static final int LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT_DEFAULT = 100;
+ private static final int SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
+ private static final int CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
+ private static final int GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT = 100 * 1000;
+ private static final int GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT = 100;
+ private static final int FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT_DEFAULT = 100;
+ private static final int LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT_DEFAULT = 100;
- private static final List<String> ENTITY_LIST_DEFAULT_VALUE =
- Arrays.asList(
- TextClassifier.TYPE_ADDRESS,
- TextClassifier.TYPE_EMAIL,
- TextClassifier.TYPE_PHONE,
- TextClassifier.TYPE_URL,
- TextClassifier.TYPE_DATE,
- TextClassifier.TYPE_DATE_TIME,
- TextClassifier.TYPE_FLIGHT_NUMBER);
- private static final List<String> CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES =
- Arrays.asList(
- ConversationAction.TYPE_TEXT_REPLY,
- ConversationAction.TYPE_CREATE_REMINDER,
- ConversationAction.TYPE_CALL_PHONE,
- ConversationAction.TYPE_OPEN_URL,
- ConversationAction.TYPE_SEND_EMAIL,
- ConversationAction.TYPE_SEND_SMS,
- ConversationAction.TYPE_TRACK_FLIGHT,
- ConversationAction.TYPE_VIEW_CALENDAR,
- ConversationAction.TYPE_VIEW_MAP,
- TYPE_ADD_CONTACT,
- TYPE_COPY);
- /**
- * < 0 : Not set. Use value from LangId model. 0 - 1: Override value in LangId model.
- *
- * @see EntityConfidence
- */
- private static final float LANG_ID_THRESHOLD_OVERRIDE_DEFAULT = -1f;
+ private static final ImmutableList<String> ENTITY_LIST_DEFAULT_VALUE =
+ ImmutableList.of(
+ TextClassifier.TYPE_ADDRESS,
+ TextClassifier.TYPE_EMAIL,
+ TextClassifier.TYPE_PHONE,
+ TextClassifier.TYPE_URL,
+ TextClassifier.TYPE_DATE,
+ TextClassifier.TYPE_DATE_TIME,
+ TextClassifier.TYPE_FLIGHT_NUMBER);
+ private static final ImmutableList<String> CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES =
+ ImmutableList.of(
+ ConversationAction.TYPE_TEXT_REPLY,
+ ConversationAction.TYPE_CREATE_REMINDER,
+ ConversationAction.TYPE_CALL_PHONE,
+ ConversationAction.TYPE_OPEN_URL,
+ ConversationAction.TYPE_SEND_EMAIL,
+ ConversationAction.TYPE_SEND_SMS,
+ ConversationAction.TYPE_TRACK_FLIGHT,
+ ConversationAction.TYPE_VIEW_CALENDAR,
+ ConversationAction.TYPE_VIEW_MAP,
+ TYPE_ADD_CONTACT,
+ TYPE_COPY);
+ /**
+ * < 0 : Not set. Use value from LangId model. 0 - 1: Override value in LangId model.
+ *
+ * @see EntityConfidence
+ */
+ private static final float LANG_ID_THRESHOLD_OVERRIDE_DEFAULT = -1f;
- private static final float TRANSLATE_ACTION_THRESHOLD_DEFAULT = 0.5f;
+ private static final float TRANSLATE_ACTION_THRESHOLD_DEFAULT = 0.5f;
- private static final boolean USER_LANGUAGE_PROFILE_ENABLED_DEFAULT = true;
- private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
- private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
- private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
- private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
+ private static final boolean USER_LANGUAGE_PROFILE_ENABLED_DEFAULT = true;
+ private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
+ private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
+ private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
+ private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
- public int getSuggestSelectionMaxRangeLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- SUGGEST_SELECTION_MAX_RANGE_LENGTH,
- SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT);
+ public int getSuggestSelectionMaxRangeLength() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ SUGGEST_SELECTION_MAX_RANGE_LENGTH,
+ SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT);
+ }
+
+ public int getClassifyTextMaxRangeLength() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CLASSIFY_TEXT_MAX_RANGE_LENGTH,
+ CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT);
+ }
+
+ public int getGenerateLinksMaxTextLength() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ GENERATE_LINKS_MAX_TEXT_LENGTH,
+ GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT);
+ }
+
+ public int getGenerateLinksLogSampleRate() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ GENERATE_LINKS_LOG_SAMPLE_RATE,
+ GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT);
+ }
+
+ public int getFrequentLanguagesBootstrappingCount() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT,
+ FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT_DEFAULT);
+ }
+
+ public int getLanguageProficiencyBootstrappingCount() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT,
+ LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT_DEFAULT);
+ }
+
+ public List<String> getEntityListDefault() {
+ return getDeviceConfigStringList(ENTITY_LIST_DEFAULT, ENTITY_LIST_DEFAULT_VALUE);
+ }
+
+ public List<String> getEntityListNotEditable() {
+ return getDeviceConfigStringList(ENTITY_LIST_NOT_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
+ }
+
+ public List<String> getEntityListEditable() {
+ return getDeviceConfigStringList(ENTITY_LIST_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
+ }
+
+ public List<String> getInAppConversationActionTypes() {
+ return getDeviceConfigStringList(
+ IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
+ }
+
+ public List<String> getNotificationConversationActionTypes() {
+ return getDeviceConfigStringList(
+ NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
+ }
+
+ public float getLangIdThresholdOverride() {
+ return DeviceConfig.getFloat(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ LANG_ID_THRESHOLD_OVERRIDE,
+ LANG_ID_THRESHOLD_OVERRIDE_DEFAULT);
+ }
+
+ public float getTranslateActionThreshold() {
+ return DeviceConfig.getFloat(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ TRANSLATE_ACTION_THRESHOLD,
+ TRANSLATE_ACTION_THRESHOLD_DEFAULT);
+ }
+
+ public boolean isUserLanguageProfileEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ USER_LANGUAGE_PROFILE_ENABLED,
+ USER_LANGUAGE_PROFILE_ENABLED_DEFAULT);
+ }
+
+ public boolean isTemplateIntentFactoryEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ TEMPLATE_INTENT_FACTORY_ENABLED,
+ TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT);
+ }
+
+ public boolean isTranslateInClassificationEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ TRANSLATE_IN_CLASSIFICATION_ENABLED,
+ TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT);
+ }
+
+ public boolean isDetectLanguagesFromTextEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ DETECT_LANGUAGES_FROM_TEXT_ENABLED,
+ DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT);
+ }
+
+ public float[] getLangIdContextSettings() {
+ return getDeviceConfigFloatArray(LANG_ID_CONTEXT_SETTINGS, LANG_ID_CONTEXT_SETTINGS_DEFAULT);
+ }
+
+ void dump(IndentingPrintWriter pw) {
+ pw.println("TextClassificationConstants:");
+ pw.increaseIndent();
+ pw.printPair("classify_text_max_range_length", getClassifyTextMaxRangeLength());
+ pw.printPair("detect_language_from_text_enabled", isDetectLanguagesFromTextEnabled());
+ pw.printPair("entity_list_default", getEntityListDefault());
+ pw.printPair("entity_list_editable", getEntityListEditable());
+ pw.printPair("entity_list_not_editable", getEntityListNotEditable());
+ pw.printPair("generate_links_log_sample_rate", getGenerateLinksLogSampleRate());
+ pw.printPair(
+ "frequent_languages_bootstrapping_count", getFrequentLanguagesBootstrappingCount());
+ pw.printPair("generate_links_max_text_length", getGenerateLinksMaxTextLength());
+ pw.printPair("in_app_conversation_action_types_default", getInAppConversationActionTypes());
+ pw.printPair("lang_id_context_settings", Arrays.toString(getLangIdContextSettings()));
+ pw.printPair("lang_id_threshold_override", getLangIdThresholdOverride());
+ pw.printPair("translate_action_threshold", getTranslateActionThreshold());
+ pw.printPair(
+ "notification_conversation_action_types_default", getNotificationConversationActionTypes());
+ pw.printPair("suggest_selection_max_range_length", getSuggestSelectionMaxRangeLength());
+ pw.printPair("user_language_profile_enabled", isUserLanguageProfileEnabled());
+ pw.printPair("template_intent_factory_enabled", isTemplateIntentFactoryEnabled());
+ pw.printPair("translate_in_classification_enabled", isTranslateInClassificationEnabled());
+ pw.printPair(
+ "language proficiency bootstrapping count", getLanguageProficiencyBootstrappingCount());
+ pw.decreaseIndent();
+ }
+
+ private static List<String> getDeviceConfigStringList(String key, List<String> defaultValue) {
+ return parse(
+ DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null), defaultValue);
+ }
+
+ private static float[] getDeviceConfigFloatArray(String key, float[] defaultValue) {
+ return parse(
+ DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null), defaultValue);
+ }
+
+ private static List<String> parse(@Nullable String listStr, List<String> defaultValue) {
+ if (listStr != null) {
+ return Collections.unmodifiableList(Arrays.asList(listStr.split(DELIMITER)));
}
+ return defaultValue;
+ }
- public int getClassifyTextMaxRangeLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CLASSIFY_TEXT_MAX_RANGE_LENGTH,
- CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT);
- }
-
- public int getGenerateLinksMaxTextLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- GENERATE_LINKS_MAX_TEXT_LENGTH,
- GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT);
- }
-
- public int getGenerateLinksLogSampleRate() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- GENERATE_LINKS_LOG_SAMPLE_RATE,
- GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT);
- }
-
- public int getFrequentLanguagesBootstrappingCount() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT,
- FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT_DEFAULT);
- }
-
- public int getLanguageProficiencyBootstrappingCount() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT,
- LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT_DEFAULT);
- }
-
- public List<String> getEntityListDefault() {
- return getDeviceConfigStringList(ENTITY_LIST_DEFAULT, ENTITY_LIST_DEFAULT_VALUE);
- }
-
- public List<String> getEntityListNotEditable() {
- return getDeviceConfigStringList(ENTITY_LIST_NOT_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
- }
-
- public List<String> getEntityListEditable() {
- return getDeviceConfigStringList(ENTITY_LIST_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
- }
-
- public List<String> getInAppConversationActionTypes() {
- return getDeviceConfigStringList(
- IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT,
- CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
- }
-
- public List<String> getNotificationConversationActionTypes() {
- return getDeviceConfigStringList(
- NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT,
- CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
- }
-
- public float getLangIdThresholdOverride() {
- return DeviceConfig.getFloat(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- LANG_ID_THRESHOLD_OVERRIDE,
- LANG_ID_THRESHOLD_OVERRIDE_DEFAULT);
- }
-
- public float getTranslateActionThreshold() {
- return DeviceConfig.getFloat(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- TRANSLATE_ACTION_THRESHOLD,
- TRANSLATE_ACTION_THRESHOLD_DEFAULT);
- }
-
- public boolean isUserLanguageProfileEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- USER_LANGUAGE_PROFILE_ENABLED,
- USER_LANGUAGE_PROFILE_ENABLED_DEFAULT);
- }
-
- public boolean isTemplateIntentFactoryEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- TEMPLATE_INTENT_FACTORY_ENABLED,
- TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT);
- }
-
- public boolean isTranslateInClassificationEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- TRANSLATE_IN_CLASSIFICATION_ENABLED,
- TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT);
- }
-
- public boolean isDetectLanguagesFromTextEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- DETECT_LANGUAGES_FROM_TEXT_ENABLED,
- DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT);
- }
-
- public float[] getLangIdContextSettings() {
- return getDeviceConfigFloatArray(
- LANG_ID_CONTEXT_SETTINGS, LANG_ID_CONTEXT_SETTINGS_DEFAULT);
- }
-
- void dump(IndentingPrintWriter pw) {
- pw.println("TextClassificationConstants:");
- pw.increaseIndent();
- pw.printPair("classify_text_max_range_length", getClassifyTextMaxRangeLength());
- pw.printPair("detect_language_from_text_enabled", isDetectLanguagesFromTextEnabled());
- pw.printPair("entity_list_default", getEntityListDefault());
- pw.printPair("entity_list_editable", getEntityListEditable());
- pw.printPair("entity_list_not_editable", getEntityListNotEditable());
- pw.printPair("generate_links_log_sample_rate", getGenerateLinksLogSampleRate());
- pw.printPair(
- "frequent_languages_bootstrapping_count", getFrequentLanguagesBootstrappingCount());
- pw.printPair("generate_links_max_text_length", getGenerateLinksMaxTextLength());
- pw.printPair("in_app_conversation_action_types_default", getInAppConversationActionTypes());
- pw.printPair("lang_id_context_settings", Arrays.toString(getLangIdContextSettings()));
- pw.printPair("lang_id_threshold_override", getLangIdThresholdOverride());
- pw.printPair("translate_action_threshold", getTranslateActionThreshold());
- pw.printPair(
- "notification_conversation_action_types_default",
- getNotificationConversationActionTypes());
- pw.printPair("suggest_selection_max_range_length", getSuggestSelectionMaxRangeLength());
- pw.printPair("user_language_profile_enabled", isUserLanguageProfileEnabled());
- pw.printPair("template_intent_factory_enabled", isTemplateIntentFactoryEnabled());
- pw.printPair("translate_in_classification_enabled", isTranslateInClassificationEnabled());
- pw.printPair(
- "language proficiency bootstrapping count",
- getLanguageProficiencyBootstrappingCount());
- pw.decreaseIndent();
- }
-
- private static List<String> getDeviceConfigStringList(String key, List<String> defaultValue) {
- return parse(
- DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null),
- defaultValue);
- }
-
- private static float[] getDeviceConfigFloatArray(String key, float[] defaultValue) {
- return parse(
- DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null),
- defaultValue);
- }
-
- private static List<String> parse(@Nullable String listStr, List<String> defaultValue) {
- if (listStr != null) {
- return Collections.unmodifiableList(Arrays.asList(listStr.split(DELIMITER)));
- }
+ private static float[] parse(@Nullable String arrayStr, float[] defaultValue) {
+ if (arrayStr != null) {
+ final List<String> split = Splitter.onPattern(DELIMITER).splitToList(arrayStr);
+ if (split.size() != defaultValue.length) {
return defaultValue;
- }
-
- private static float[] parse(@Nullable String arrayStr, float[] defaultValue) {
- if (arrayStr != null) {
- final String[] split = arrayStr.split(DELIMITER);
- if (split.length != defaultValue.length) {
- return defaultValue;
- }
- final float[] result = new float[split.length];
- for (int i = 0; i < split.length; i++) {
- try {
- result[i] = Float.parseFloat(split[i]);
- } catch (NumberFormatException e) {
- return defaultValue;
- }
- }
- return result;
- } else {
- return defaultValue;
+ }
+ final float[] result = new float[split.size()];
+ for (int i = 0; i < split.size(); i++) {
+ try {
+ result[i] = Float.parseFloat(split.get(i));
+ } catch (NumberFormatException e) {
+ return defaultValue;
}
+ }
+ return result;
+ } else {
+ return defaultValue;
}
+ }
}
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index f385684..9babf30 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -37,15 +37,10 @@
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
-
import androidx.annotation.GuardedBy;
-import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
import androidx.annotation.WorkerThread;
import androidx.collection.ArraySet;
import androidx.core.util.Pair;
-import androidx.core.util.Preconditions;
-
import com.android.textclassifier.ActionsModelParamsSupplier.ActionsModelParams;
import com.android.textclassifier.intent.ClassificationIntentFactory;
import com.android.textclassifier.intent.LabeledIntent;
@@ -59,17 +54,16 @@
import com.android.textclassifier.ulp.LanguageProfileAnalyzer;
import com.android.textclassifier.ulp.LanguageProfileUpdater;
import com.android.textclassifier.utils.IndentingPrintWriter;
-
import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;
-import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors;
-
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.time.Instant;
+import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Collection;
@@ -80,6 +74,7 @@
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.function.Supplier;
+import javax.annotation.Nullable;
/**
* Default implementation of the {@link TextClassifier} interface.
@@ -90,901 +85,872 @@
*/
final class TextClassifierImpl {
- private static final String TAG = "TextClassifierImpl";
+ private static final String TAG = "TextClassifierImpl";
- private static final boolean DEBUG = false;
+ private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/");
+ // Annotator
+ private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX =
+ "textclassifier\\.(.*)\\.model";
+ private static final File ANNOTATOR_UPDATED_MODEL_FILE =
+ new File("/data/misc/textclassifier/textclassifier.model");
- private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/");
- // Annotator
- private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX =
- "textclassifier\\.(.*)\\.model";
- private static final File ANNOTATOR_UPDATED_MODEL_FILE =
- new File("/data/misc/textclassifier/textclassifier.model");
+ // LangID
+ private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model";
+ private static final File UPDATED_LANG_ID_MODEL_FILE =
+ new File("/data/misc/textclassifier/lang_id.model");
- // LangID
- private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model";
- private static final File UPDATED_LANG_ID_MODEL_FILE =
- new File("/data/misc/textclassifier/lang_id.model");
+ // Actions
+ private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX =
+ "actions_suggestions\\.(.*)\\.model";
+ private static final File UPDATED_ACTIONS_MODEL =
+ new File("/data/misc/textclassifier/actions_suggestions.model");
- // Actions
- private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX =
- "actions_suggestions\\.(.*)\\.model";
- private static final File UPDATED_ACTIONS_MODEL =
- new File("/data/misc/textclassifier/actions_suggestions.model");
+ private final Context context;
+ private final TextClassifier fallback;
+ private final GenerateLinksLogger generateLinksLogger;
- private final Context mContext;
- private final TextClassifier mFallback;
- private final GenerateLinksLogger mGenerateLinksLogger;
+ private final Object lock = new Object();
- private final Object mLock = new Object();
+ @GuardedBy("lock")
+ private ModelFileManager.ModelFile annotatorModelInUse;
- @GuardedBy("mLock")
- private ModelFileManager.ModelFile mAnnotatorModelInUse;
+ @GuardedBy("lock")
+ private AnnotatorModel annotatorImpl;
- @GuardedBy("mLock")
- private AnnotatorModel mAnnotatorImpl;
+ @GuardedBy("lock")
+ private ModelFileManager.ModelFile langIdModelInUse;
- @GuardedBy("mLock")
- private ModelFileManager.ModelFile mLangIdModelInUse;
+ @GuardedBy("lock")
+ private LangIdModel langIdImpl;
- @GuardedBy("mLock")
- private LangIdModel mLangIdImpl;
+ @GuardedBy("lock")
+ private ModelFileManager.ModelFile actionModelInUse;
- @GuardedBy("mLock")
- private ModelFileManager.ModelFile mActionModelInUse;
+ @GuardedBy("lock")
+ private ActionsSuggestionsModel actionsImpl;
- @GuardedBy("mLock")
- private ActionsSuggestionsModel mActionsImpl;
+ private final TextClassifierEventLogger textClassifierEventLogger =
+ new TextClassifierEventLogger();
- private final TextClassifierEventLogger mTextClassifierEventLogger =
- new TextClassifierEventLogger();
+ private final TextClassificationConstants settings;
- private final TextClassificationConstants mSettings;
+ private final ModelFileManager annotatorModelFileManager;
+ private final ModelFileManager langIdModelFileManager;
+ private final ModelFileManager actionsModelFileManager;
+ private final LanguageProfileUpdater languageProfileUpdater;
+ private final LanguageProfileAnalyzer languageProfileAnalyzer;
+ private final ClassificationIntentFactory classificationIntentFactory;
+ private final TemplateIntentFactory templateIntentFactory;
+ private final Supplier<ActionsModelParams> actionsModelParamsSupplier;
- private final ModelFileManager mAnnotatorModelFileManager;
- private final ModelFileManager mLangIdModelFileManager;
- private final ModelFileManager mActionsModelFileManager;
- private final LanguageProfileUpdater mLanguageProfileUpdater;
- private final LanguageProfileAnalyzer mLanguageProfileAnalyzer;
- private final ClassificationIntentFactory mClassificationIntentFactory;
- private final TemplateIntentFactory mTemplateIntentFactory;
- private final Supplier<ActionsModelParams> mActionsModelParamsSupplier;
+ TextClassifierImpl(
+ Context context, TextClassificationConstants settings, TextClassifier fallback) {
+ this.context = Preconditions.checkNotNull(context);
+ this.fallback = Preconditions.checkNotNull(fallback);
+ this.settings = Preconditions.checkNotNull(settings);
+ generateLinksLogger = new GenerateLinksLogger(this.settings.getGenerateLinksLogSampleRate());
+ annotatorModelFileManager =
+ new ModelFileManager(
+ new ModelFileManager.ModelFileSupplierImpl(
+ FACTORY_MODEL_DIR,
+ ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX,
+ ANNOTATOR_UPDATED_MODEL_FILE,
+ AnnotatorModel::getVersion,
+ AnnotatorModel::getLocales));
+ langIdModelFileManager =
+ new ModelFileManager(
+ new ModelFileManager.ModelFileSupplierImpl(
+ FACTORY_MODEL_DIR,
+ LANG_ID_FACTORY_MODEL_FILENAME_REGEX,
+ UPDATED_LANG_ID_MODEL_FILE,
+ LangIdModel::getVersion,
+ fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
+ actionsModelFileManager =
+ new ModelFileManager(
+ new ModelFileManager.ModelFileSupplierImpl(
+ FACTORY_MODEL_DIR,
+ ACTIONS_FACTORY_MODEL_FILENAME_REGEX,
+ UPDATED_ACTIONS_MODEL,
+ ActionsSuggestionsModel::getVersion,
+ ActionsSuggestionsModel::getLocales));
+ languageProfileUpdater =
+ new LanguageProfileUpdater(
+ this.context, MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor()));
+ languageProfileAnalyzer = LanguageProfileAnalyzer.create(context, this.settings);
- TextClassifierImpl(
- Context context, TextClassificationConstants settings, TextClassifier fallback) {
- mContext = Preconditions.checkNotNull(context);
- mFallback = Preconditions.checkNotNull(fallback);
- mSettings = Preconditions.checkNotNull(settings);
- mGenerateLinksLogger = new GenerateLinksLogger(mSettings.getGenerateLinksLogSampleRate());
- mAnnotatorModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX,
- ANNOTATOR_UPDATED_MODEL_FILE,
- AnnotatorModel::getVersion,
- AnnotatorModel::getLocales));
- mLangIdModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- LANG_ID_FACTORY_MODEL_FILENAME_REGEX,
- UPDATED_LANG_ID_MODEL_FILE,
- LangIdModel::getVersion,
- fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
- mActionsModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- ACTIONS_FACTORY_MODEL_FILENAME_REGEX,
- UPDATED_ACTIONS_MODEL,
- ActionsSuggestionsModel::getVersion,
- ActionsSuggestionsModel::getLocales));
- mLanguageProfileUpdater =
- new LanguageProfileUpdater(
- mContext,
- MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor()));
- mLanguageProfileAnalyzer = LanguageProfileAnalyzer.create(context, mSettings);
+ templateIntentFactory = new TemplateIntentFactory();
+ classificationIntentFactory =
+ this.settings.isTemplateIntentFactoryEnabled()
+ ? new TemplateClassificationIntentFactory(
+ templateIntentFactory, new LegacyClassificationIntentFactory())
+ : new LegacyClassificationIntentFactory();
+ actionsModelParamsSupplier =
+ new ActionsModelParamsSupplier(
+ this.context,
+ () -> {
+ synchronized (lock) {
+ // Clear actionsImpl here, so that we will create a new
+ // ActionsSuggestionsModel object with the new flag in the next
+ // request.
+ actionsImpl = null;
+ actionModelInUse = null;
+ }
+ });
+ }
- mTemplateIntentFactory = new TemplateIntentFactory();
- mClassificationIntentFactory =
- mSettings.isTemplateIntentFactoryEnabled()
- ? new TemplateClassificationIntentFactory(
- mTemplateIntentFactory, new LegacyClassificationIntentFactory())
- : new LegacyClassificationIntentFactory();
- mActionsModelParamsSupplier =
- new ActionsModelParamsSupplier(
- mContext,
- () -> {
- synchronized (mLock) {
- // Clear mActionsImpl here, so that we will create a new
- // ActionsSuggestionsModel object with the new flag in the next
- // request.
- mActionsImpl = null;
- mActionModelInUse = null;
- }
- });
- }
+ TextClassifierImpl(Context context, TextClassificationConstants settings) {
+ this(context, settings, TextClassifier.NO_OP);
+ }
- TextClassifierImpl(Context context, TextClassificationConstants settings) {
- this(context, settings, TextClassifier.NO_OP);
- }
-
- @WorkerThread
- TextSelection suggestSelection(TextSelection.Request request) {
- Preconditions.checkNotNull(request);
- checkMainThread();
- try {
- final int rangeLength = request.getEndIndex() - request.getStartIndex();
- final String string = request.getText().toString();
- if (string.length() > 0
- && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
- final String localesString = concatenateLocales(request.getDefaultLocales());
- final String detectLanguageTags =
- String.join(
- ",",
- detectLanguages(request.getText(), getLangIdThreshold())
- .getEntities());
- final ZonedDateTime refTime = ZonedDateTime.now();
- final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
- final int[] startEnd =
- annotatorImpl.suggestSelection(
- string,
- request.getStartIndex(),
- request.getEndIndex(),
- new AnnotatorModel.SelectionOptions(
- localesString, detectLanguageTags));
- final int start = startEnd[0];
- final int end = startEnd[1];
- if (start < end
- && start >= 0
- && end <= string.length()
- && start <= request.getStartIndex()
- && end >= request.getEndIndex()) {
- final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
- final AnnotatorModel.ClassificationResult[] results =
- annotatorImpl.classifyText(
- string,
- start,
- end,
- new AnnotatorModel.ClassificationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- detectLanguageTags),
- // Passing null here to suppress intent generation
- // TODO: Use an explicit flag to suppress it.
- /* appContext */ null,
- /* deviceLocales */ null);
- final int size = results.length;
- for (int i = 0; i < size; i++) {
- tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
- }
- return tsBuilder
- .setId(createId(string, request.getStartIndex(), request.getEndIndex()))
- .build();
- } else {
- // We can not trust the result. Log the issue and ignore the result.
- TcLog.d(TAG, "Got bad indices for input text. Ignoring result.");
- }
- }
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(
- TAG,
- "Error suggesting selection for text. No changes to selection suggested.",
- t);
+ @WorkerThread
+ TextSelection suggestSelection(TextSelection.Request request) {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ try {
+ final int rangeLength = request.getEndIndex() - request.getStartIndex();
+ final String string = request.getText().toString();
+ if (string.length() > 0 && rangeLength <= settings.getSuggestSelectionMaxRangeLength()) {
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ final String detectLanguageTags =
+ String.join(
+ ",", detectLanguages(request.getText(), getLangIdThreshold()).getEntities());
+ final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
+ final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+ final int[] startEnd =
+ annotatorImpl.suggestSelection(
+ string,
+ request.getStartIndex(),
+ request.getEndIndex(),
+ new AnnotatorModel.SelectionOptions(localesString, detectLanguageTags));
+ final int start = startEnd[0];
+ final int end = startEnd[1];
+ if (start < end
+ && start >= 0
+ && end <= string.length()
+ && start <= request.getStartIndex()
+ && end >= request.getEndIndex()) {
+ final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
+ final AnnotatorModel.ClassificationResult[] results =
+ annotatorImpl.classifyText(
+ string,
+ start,
+ end,
+ new AnnotatorModel.ClassificationOptions(
+ refTime.toInstant().toEpochMilli(),
+ refTime.getZone().getId(),
+ localesString,
+ detectLanguageTags),
+ // Passing null here to suppress intent generation
+ // TODO: Use an explicit flag to suppress it.
+ /* appContext */ null,
+ /* deviceLocales */ null);
+ final int size = results.length;
+ for (int i = 0; i < size; i++) {
+ tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
+ }
+ return tsBuilder
+ .setId(createId(string, request.getStartIndex(), request.getEndIndex()))
+ .build();
+ } else {
+ // We can not trust the result. Log the issue and ignore the result.
+ TcLog.d(TAG, "Got bad indices for input text. Ignoring result.");
}
- // Getting here means something went wrong, return a NO_OP result.
- return mFallback.suggestSelection(request);
+ }
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error suggesting selection for text. No changes to selection suggested.", t);
}
+ // Getting here means something went wrong, return a NO_OP result.
+ return fallback.suggestSelection(request);
+ }
- @WorkerThread
- TextClassification classifyText(TextClassification.Request request) {
- Preconditions.checkNotNull(request);
- checkMainThread();
- try {
- List<String> detectLanguageTags =
- detectLanguages(request.getText(), getLangIdThreshold()).getEntities();
- if (mSettings.isUserLanguageProfileEnabled()) {
- ListenableFuture<Void> ignoredResult =
- mLanguageProfileUpdater.updateFromClassifyTextAsync(detectLanguageTags);
- }
- final int rangeLength = request.getEndIndex() - request.getStartIndex();
- final String string = request.getText().toString();
- if (string.length() > 0 && rangeLength <= mSettings.getClassifyTextMaxRangeLength()) {
- final String localesString = concatenateLocales(request.getDefaultLocales());
- final ZonedDateTime refTime =
- request.getReferenceTime() != null
- ? request.getReferenceTime()
- : ZonedDateTime.now();
- final AnnotatorModel.ClassificationResult[] results =
- getAnnotatorImpl(request.getDefaultLocales())
- .classifyText(
- string,
- request.getStartIndex(),
- request.getEndIndex(),
- new AnnotatorModel.ClassificationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- String.join(",", detectLanguageTags)),
- mContext,
- getResourceLocalesString());
- if (results.length > 0) {
- return createClassificationResult(
- results,
- string,
- request.getStartIndex(),
- request.getEndIndex(),
- refTime.toInstant());
- }
- }
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error getting text classification info.", t);
+ @WorkerThread
+ TextClassification classifyText(TextClassification.Request request) {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ try {
+ List<String> detectLanguageTags =
+ detectLanguages(request.getText(), getLangIdThreshold()).getEntities();
+ if (settings.isUserLanguageProfileEnabled()) {
+ languageProfileUpdater.updateFromClassifyTextAsync(detectLanguageTags);
+ }
+ final int rangeLength = request.getEndIndex() - request.getStartIndex();
+ final String string = request.getText().toString();
+ if (string.length() > 0 && rangeLength <= settings.getClassifyTextMaxRangeLength()) {
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ final ZonedDateTime refTime =
+ request.getReferenceTime() != null
+ ? request.getReferenceTime()
+ : ZonedDateTime.now(ZoneId.systemDefault());
+ final AnnotatorModel.ClassificationResult[] results =
+ getAnnotatorImpl(request.getDefaultLocales())
+ .classifyText(
+ string,
+ request.getStartIndex(),
+ request.getEndIndex(),
+ new AnnotatorModel.ClassificationOptions(
+ refTime.toInstant().toEpochMilli(),
+ refTime.getZone().getId(),
+ localesString,
+ String.join(",", detectLanguageTags)),
+ context,
+ getResourceLocalesString());
+ if (results.length > 0) {
+ return createClassificationResult(
+ results, string, request.getStartIndex(), request.getEndIndex(), refTime.toInstant());
}
- // Getting here means something went wrong, return a NO_OP result.
- return mFallback.classifyText(request);
+ }
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error getting text classification info.", t);
}
+ // Getting here means something went wrong, return a NO_OP result.
+ return fallback.classifyText(request);
+ }
- @WorkerThread
- TextLinks generateLinks(@NonNull TextLinks.Request request) {
- Preconditions.checkNotNull(request);
- Preconditions.checkArgumentInRange(
- request.getText().length(), 0, getMaxGenerateLinksTextLength(), "text.length()");
- checkMainThread();
+ @WorkerThread
+ TextLinks generateLinks(TextLinks.Request request) {
+ Preconditions.checkNotNull(request);
+ Preconditions.checkArgument(
+ request.getText().length() <= getMaxGenerateLinksTextLength(),
+ "text.length() cannot be greater than %s",
+ getMaxGenerateLinksTextLength());
+ checkMainThread();
- final String textString = request.getText().toString();
- final TextLinks.Builder builder = new TextLinks.Builder(textString);
+ final String textString = request.getText().toString();
+ final TextLinks.Builder builder = new TextLinks.Builder(textString);
- try {
- final long startTimeMs = System.currentTimeMillis();
- final ZonedDateTime refTime = ZonedDateTime.now();
- final Collection<String> entitiesToIdentify =
- request.getEntityConfig() != null
- ? request.getEntityConfig()
- .resolveEntityListModifications(
- getEntitiesForHints(
- request.getEntityConfig().getHints()))
- : mSettings.getEntityListDefault();
- final String localesString = concatenateLocales(request.getDefaultLocales());
- final String detectLanguageTags =
- String.join(
- ",",
- detectLanguages(request.getText(), getLangIdThreshold()).getEntities());
- final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
- final boolean isSerializedEntityDataEnabled =
- ExtrasUtils.isSerializedEntityDataEnabled(request);
- final AnnotatorModel.AnnotatedSpan[] annotations =
- annotatorImpl.annotate(
- textString,
- new AnnotatorModel.AnnotationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- detectLanguageTags,
- entitiesToIdentify,
- AnnotatorModel.AnnotationUsecase.SMART.getValue(),
- isSerializedEntityDataEnabled));
- for (AnnotatorModel.AnnotatedSpan span : annotations) {
- final AnnotatorModel.ClassificationResult[] results = span.getClassification();
- if (results.length == 0
- || !entitiesToIdentify.contains(results[0].getCollection())) {
- continue;
- }
- final Map<String, Float> entityScores = new ArrayMap<>();
- for (int i = 0; i < results.length; i++) {
- entityScores.put(results[i].getCollection(), results[i].getScore());
- }
- Bundle extras = new Bundle();
- if (isSerializedEntityDataEnabled) {
- ExtrasUtils.putEntities(extras, results);
- }
- builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
- }
- final TextLinks links = builder.build();
- final long endTimeMs = System.currentTimeMillis();
- final String callingPackageName =
- request.getCallingPackageName() == null
- ? mContext.getPackageName() // local (in process) TC.
- : request.getCallingPackageName();
- mGenerateLinksLogger.logGenerateLinks(
- request.getText(), links, callingPackageName, endTimeMs - startTimeMs);
- return links;
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error getting links info.", t);
+ try {
+ final long startTimeMs = System.currentTimeMillis();
+ final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
+ final Collection<String> entitiesToIdentify =
+ request.getEntityConfig() != null
+ ? request
+ .getEntityConfig()
+ .resolveEntityListModifications(
+ getEntitiesForHints(request.getEntityConfig().getHints()))
+ : settings.getEntityListDefault();
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ final String detectLanguageTags =
+ String.join(",", detectLanguages(request.getText(), getLangIdThreshold()).getEntities());
+ final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+ final boolean isSerializedEntityDataEnabled =
+ ExtrasUtils.isSerializedEntityDataEnabled(request);
+ final AnnotatorModel.AnnotatedSpan[] annotations =
+ annotatorImpl.annotate(
+ textString,
+ new AnnotatorModel.AnnotationOptions(
+ refTime.toInstant().toEpochMilli(),
+ refTime.getZone().getId(),
+ localesString,
+ detectLanguageTags,
+ entitiesToIdentify,
+ AnnotatorModel.AnnotationUsecase.SMART.getValue(),
+ isSerializedEntityDataEnabled));
+ for (AnnotatorModel.AnnotatedSpan span : annotations) {
+ final AnnotatorModel.ClassificationResult[] results = span.getClassification();
+ if (results.length == 0 || !entitiesToIdentify.contains(results[0].getCollection())) {
+ continue;
}
- return mFallback.generateLinks(request);
- }
-
- int getMaxGenerateLinksTextLength() {
- return mSettings.getGenerateLinksMaxTextLength();
- }
-
- private Collection<String> getEntitiesForHints(Collection<String> hints) {
- final boolean editable = hints.contains(TextClassifier.HINT_TEXT_IS_EDITABLE);
- final boolean notEditable = hints.contains(TextClassifier.HINT_TEXT_IS_NOT_EDITABLE);
-
- // Use the default if there is no hint, or conflicting ones.
- final boolean useDefault = editable == notEditable;
- if (useDefault) {
- return mSettings.getEntityListDefault();
- } else if (editable) {
- return mSettings.getEntityListEditable();
- } else { // notEditable
- return mSettings.getEntityListNotEditable();
- }
- }
-
- void onSelectionEvent(SelectionEvent event) {
- TextClassifierEvent textClassifierEvent =
- SelectionEventConverter.toTextClassifierEvent(event);
- if (textClassifierEvent == null) {
- return;
- }
- onTextClassifierEvent(event.getSessionId(), textClassifierEvent);
- }
-
- void onTextClassifierEvent(
- @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) {
- mTextClassifierEventLogger.writeEvent(sessionId, event);
- if (mSettings.isUserLanguageProfileEnabled()) {
- mLanguageProfileAnalyzer.onTextClassifierEven(event);
- }
- }
-
- TextLanguage detectLanguage(@NonNull TextLanguage.Request request) {
- Preconditions.checkNotNull(request);
- checkMainThread();
- try {
- final TextLanguage.Builder builder = new TextLanguage.Builder();
- final LangIdModel.LanguageResult[] langResults =
- getLangIdImpl().detectLanguages(request.getText().toString());
- for (int i = 0; i < langResults.length; i++) {
- builder.putLocale(
- ULocale.forLanguageTag(langResults[i].getLanguage()),
- langResults[i].getScore());
- }
- return builder.build();
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error detecting text language.", t);
- }
- return mFallback.detectLanguage(request);
- }
-
- ConversationActions suggestConversationActions(ConversationActions.Request request) {
- Preconditions.checkNotNull(request);
- checkMainThread();
- if (mSettings.isUserLanguageProfileEnabled()) {
- // FIXME: Reuse the LangID result.
- ListenableFuture<Void> ignoredResult =
- mLanguageProfileUpdater.updateFromConversationActionsAsync(
- request,
- text -> detectLanguages(text, getLangIdThreshold()).getEntities());
- }
- try {
- ActionsSuggestionsModel actionsImpl = getActionsImpl();
- if (actionsImpl == null) {
- // Actions model is optional, fallback if it is not available.
- return mFallback.suggestConversationActions(request);
- }
- ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
- ActionsSuggestionsHelper.toNativeMessages(
- request.getConversation(),
- text -> detectLanguages(text, getLangIdThreshold()).getEntities());
- if (nativeMessages.length == 0) {
- return mFallback.suggestConversationActions(request);
- }
- ActionsSuggestionsModel.Conversation nativeConversation =
- new ActionsSuggestionsModel.Conversation(nativeMessages);
-
- ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
- actionsImpl.suggestActionsWithIntents(
- nativeConversation,
- null,
- mContext,
- getResourceLocalesString(),
- getAnnotatorImpl(LocaleList.getDefault()));
- return createConversationActionResult(request, nativeSuggestions);
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error suggesting conversation actions.", t);
- }
- return mFallback.suggestConversationActions(request);
- }
-
- /**
- * Returns the {@link ConversationAction} result, with a non-null extras.
- *
- * <p>Whenever the RemoteAction is non-null, you can expect its corresponding intent with a
- * non-null component name is in the extras.
- */
- private ConversationActions createConversationActionResult(
- ConversationActions.Request request,
- ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions) {
- Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
- List<ConversationAction> conversationActions = new ArrayList<>();
- for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : nativeSuggestions) {
- String actionType = nativeSuggestion.getActionType();
- if (!expectedTypes.contains(actionType)) {
- continue;
- }
- LabeledIntent.Result labeledIntentResult =
- ActionsSuggestionsHelper.createLabeledIntentResult(
- mContext, mTemplateIntentFactory, nativeSuggestion);
- RemoteAction remoteAction = null;
- Bundle extras = new Bundle();
- if (labeledIntentResult != null) {
- remoteAction = labeledIntentResult.remoteAction;
- ExtrasUtils.putActionIntent(extras, labeledIntentResult.resolvedIntent);
- }
- ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData());
- ExtrasUtils.putEntitiesExtras(
- extras,
- TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData()));
- conversationActions.add(
- new ConversationAction.Builder(actionType)
- .setConfidenceScore(nativeSuggestion.getScore())
- .setTextReply(nativeSuggestion.getResponseText())
- .setAction(remoteAction)
- .setExtras(extras)
- .build());
- }
- conversationActions =
- ActionsSuggestionsHelper.removeActionsWithDuplicates(conversationActions);
- if (request.getMaxSuggestions() >= 0
- && conversationActions.size() > request.getMaxSuggestions()) {
- conversationActions = conversationActions.subList(0, request.getMaxSuggestions());
- }
- synchronized (mLock) {
- String resultId =
- ActionsSuggestionsHelper.createResultId(
- mContext,
- request.getConversation(),
- mActionModelInUse.getVersion(),
- mActionModelInUse.getSupportedLocales());
- return new ConversationActions(conversationActions, resultId);
- }
- }
-
- private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) {
- List<String> defaultActionTypes =
- request.getHints().contains(ConversationActions.Request.HINT_FOR_NOTIFICATION)
- ? mSettings.getNotificationConversationActionTypes()
- : mSettings.getInAppConversationActionTypes();
- return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
- }
-
- private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException {
- synchronized (mLock) {
- localeList = localeList == null ? LocaleList.getDefault() : localeList;
- final ModelFileManager.ModelFile bestModel =
- mAnnotatorModelFileManager.findBestModelFile(localeList);
- if (bestModel == null) {
- throw new FileNotFoundException(
- "No annotator model for " + localeList.toLanguageTags());
- }
- if (mAnnotatorImpl == null || !Objects.equals(mAnnotatorModelInUse, bestModel)) {
- TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- try {
- if (pfd != null) {
- // The current annotator model may be still used by another thread / model.
- // Do not call close() here, and let the GC to clean it up when no one else
- // is using it.
- mAnnotatorImpl = new AnnotatorModel(pfd.getFd());
- mAnnotatorModelInUse = bestModel;
- }
- } finally {
- maybeCloseAndLogError(pfd);
- }
- }
- return mAnnotatorImpl;
- }
- }
-
- private LangIdModel getLangIdImpl() throws FileNotFoundException {
- synchronized (mLock) {
- final ModelFileManager.ModelFile bestModel =
- mLangIdModelFileManager.findBestModelFile(null);
- if (bestModel == null) {
- throw new FileNotFoundException("No LangID model is found");
- }
- if (mLangIdImpl == null || !Objects.equals(mLangIdModelInUse, bestModel)) {
- TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- try {
- if (pfd != null) {
- mLangIdImpl = new LangIdModel(pfd.getFd());
- mLangIdModelInUse = bestModel;
- }
- } finally {
- maybeCloseAndLogError(pfd);
- }
- }
- return mLangIdImpl;
- }
- }
-
- @Nullable
- private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException {
- synchronized (mLock) {
- // TODO: Use LangID to determine the locale we should use here?
- final ModelFileManager.ModelFile bestModel =
- mActionsModelFileManager.findBestModelFile(LocaleList.getDefault());
- if (bestModel == null) {
- return null;
- }
- if (mActionsImpl == null || !Objects.equals(mActionModelInUse, bestModel)) {
- TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- try {
- if (pfd == null) {
- TcLog.d(TAG, "Failed to read the model file: " + bestModel.getPath());
- return null;
- }
- ActionsModelParams params = mActionsModelParamsSupplier.get();
- mActionsImpl =
- new ActionsSuggestionsModel(
- pfd.getFd(), params.getSerializedPreconditions(bestModel));
- mActionModelInUse = bestModel;
- } finally {
- maybeCloseAndLogError(pfd);
- }
- }
- return mActionsImpl;
- }
- }
-
- private String createId(String text, int start, int end) {
- synchronized (mLock) {
- return ResultIdUtils.createId(
- mContext,
- text,
- start,
- end,
- mAnnotatorModelInUse.getVersion(),
- mAnnotatorModelInUse.getSupportedLocales());
- }
- }
-
- private static String concatenateLocales(@Nullable LocaleList locales) {
- return (locales == null) ? "" : locales.toLanguageTags();
- }
-
- private TextClassification createClassificationResult(
- AnnotatorModel.ClassificationResult[] classifications,
- String text,
- int start,
- int end,
- @Nullable Instant referenceTime) {
- final String classifiedText = text.substring(start, end);
- final TextClassification.Builder builder =
- new TextClassification.Builder().setText(classifiedText);
-
- final int typeCount = classifications.length;
- AnnotatorModel.ClassificationResult highestScoringResult =
- typeCount > 0 ? classifications[0] : null;
- for (int i = 0; i < typeCount; i++) {
- builder.setEntityType(
- classifications[i].getCollection(), classifications[i].getScore());
- if (classifications[i].getScore() > highestScoringResult.getScore()) {
- highestScoringResult = classifications[i];
- }
- }
- final Pair<Bundle, Bundle> languagesBundles = generateLanguageBundles(text, start, end);
- final Bundle textLanguagesBundle = languagesBundles.first;
- final Bundle foreignLanguageBundle = languagesBundles.second;
-
- boolean isPrimaryAction = true;
- final List<LabeledIntent> labeledIntents =
- mClassificationIntentFactory.create(
- mContext,
- classifiedText,
- foreignLanguageBundle != null,
- referenceTime,
- highestScoringResult);
- final LabeledIntent.TitleChooser titleChooser =
- (labeledIntent, resolveInfo) -> labeledIntent.titleWithoutEntity;
-
- ArrayList<Intent> actionIntents = new ArrayList<>();
- for (LabeledIntent labeledIntent : labeledIntents) {
- final LabeledIntent.Result result =
- labeledIntent.resolve(mContext, titleChooser, textLanguagesBundle);
- if (result == null) {
- continue;
- }
-
- final Intent intent = result.resolvedIntent;
- final RemoteAction action = result.remoteAction;
- if (isPrimaryAction) {
- // For O backwards compatibility, the first RemoteAction is also written to the
- // legacy API fields.
- builder.setIcon(action.getIcon().loadDrawable(mContext));
- builder.setLabel(action.getTitle().toString());
- builder.setIntent(intent);
- builder.setOnClickListener(
- createIntentOnClickListener(
- createPendingIntent(mContext, intent, labeledIntent.requestCode)));
- isPrimaryAction = false;
- }
- builder.addAction(action);
- actionIntents.add(intent);
+ final Map<String, Float> entityScores = new ArrayMap<>();
+ for (int i = 0; i < results.length; i++) {
+ entityScores.put(results[i].getCollection(), results[i].getScore());
}
Bundle extras = new Bundle();
- ExtrasUtils.putForeignLanguageExtra(extras, foreignLanguageBundle);
- if (actionIntents.stream().anyMatch(Objects::nonNull)) {
- ExtrasUtils.putActionsIntents(extras, actionIntents);
+ if (isSerializedEntityDataEnabled) {
+ ExtrasUtils.putEntities(extras, results);
}
- ExtrasUtils.putEntities(extras, classifications);
- builder.setExtras(extras);
- return builder.setId(createId(text, start, end)).build();
+ builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
+ }
+ final TextLinks links = builder.build();
+ final long endTimeMs = System.currentTimeMillis();
+ final String callingPackageName =
+ request.getCallingPackageName() == null
+ ? context.getPackageName() // local (in process) TC.
+ : request.getCallingPackageName();
+ generateLinksLogger.logGenerateLinks(
+ request.getText(), links, callingPackageName, endTimeMs - startTimeMs);
+ return links;
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error getting links info.", t);
}
+ return fallback.generateLinks(request);
+ }
- private static OnClickListener createIntentOnClickListener(
- @NonNull final PendingIntent intent) {
- Preconditions.checkNotNull(intent);
- return v -> {
- try {
- intent.send();
- } catch (PendingIntent.CanceledException e) {
- TcLog.e(TAG, "Error sending PendingIntent", e);
- }
- };
+ int getMaxGenerateLinksTextLength() {
+ return settings.getGenerateLinksMaxTextLength();
+ }
+
+ private Collection<String> getEntitiesForHints(Collection<String> hints) {
+ final boolean editable = hints.contains(TextClassifier.HINT_TEXT_IS_EDITABLE);
+ final boolean notEditable = hints.contains(TextClassifier.HINT_TEXT_IS_NOT_EDITABLE);
+
+ // Use the default if there is no hint, or conflicting ones.
+ final boolean useDefault = editable == notEditable;
+ if (useDefault) {
+ return settings.getEntityListDefault();
+ } else if (editable) {
+ return settings.getEntityListEditable();
+ } else { // notEditable
+ return settings.getEntityListNotEditable();
}
+ }
- /**
- * Returns a bundle pair with language detection information for extras.
- *
- * <p>Pair.first = textLanguagesBundle - A bundle containing information about all detected
- * languages in the text. May be null if language detection fails or is disabled. This is
- * typically expected to be added to a textClassifier generated remote action intent. See {@link
- * ExtrasUtils#putTextLanguagesExtra(Bundle, Bundle)}. See {@link
- * ExtrasUtils#getTopLanguage(Intent)}.
- *
- * <p>Pair.second = foreignLanguageBundle - A bundle with the language and confidence score if
- * the system finds the text to be in a foreign language. Otherwise is null. See {@link
- * TextClassification.Builder#setForeignLanguageExtra(Bundle)}.
- *
- * @param context the context of the text to detect languages for
- * @param start the start index of the text
- * @param end the end index of the text
- */
- // TODO: Revisit this algorithm.
- // TODO: Consider making this public API.
- private Pair<Bundle, Bundle> generateLanguageBundles(String context, int start, int end) {
- if (!mSettings.isTranslateInClassificationEnabled()) {
+ void onSelectionEvent(SelectionEvent event) {
+ TextClassifierEvent textClassifierEvent = SelectionEventConverter.toTextClassifierEvent(event);
+ if (textClassifierEvent == null) {
+ return;
+ }
+ onTextClassifierEvent(event.getSessionId(), textClassifierEvent);
+ }
+
+ void onTextClassifierEvent(
+ @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) {
+ textClassifierEventLogger.writeEvent(sessionId, event);
+ if (settings.isUserLanguageProfileEnabled()) {
+ languageProfileAnalyzer.onTextClassifierEven(event);
+ }
+ }
+
+ TextLanguage detectLanguage(TextLanguage.Request request) {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ try {
+ final TextLanguage.Builder builder = new TextLanguage.Builder();
+ final LangIdModel.LanguageResult[] langResults =
+ getLangIdImpl().detectLanguages(request.getText().toString());
+ for (int i = 0; i < langResults.length; i++) {
+ builder.putLocale(
+ ULocale.forLanguageTag(langResults[i].getLanguage()), langResults[i].getScore());
+ }
+ return builder.build();
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error detecting text language.", t);
+ }
+ return fallback.detectLanguage(request);
+ }
+
+ ConversationActions suggestConversationActions(ConversationActions.Request request) {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ if (settings.isUserLanguageProfileEnabled()) {
+ // TODO(tonymak): Reuse the LangID result.
+ languageProfileUpdater.updateFromConversationActionsAsync(
+ request, text -> detectLanguages(text, getLangIdThreshold()).getEntities());
+ }
+ try {
+ ActionsSuggestionsModel actionsImpl = getActionsImpl();
+ if (actionsImpl == null) {
+ // Actions model is optional, fallback if it is not available.
+ return fallback.suggestConversationActions(request);
+ }
+ ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
+ ActionsSuggestionsHelper.toNativeMessages(
+ request.getConversation(),
+ text -> detectLanguages(text, getLangIdThreshold()).getEntities());
+ if (nativeMessages.length == 0) {
+ return fallback.suggestConversationActions(request);
+ }
+ ActionsSuggestionsModel.Conversation nativeConversation =
+ new ActionsSuggestionsModel.Conversation(nativeMessages);
+
+ ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
+ actionsImpl.suggestActionsWithIntents(
+ nativeConversation,
+ null,
+ context,
+ getResourceLocalesString(),
+ getAnnotatorImpl(LocaleList.getDefault()));
+ return createConversationActionResult(request, nativeSuggestions);
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error suggesting conversation actions.", t);
+ }
+ return fallback.suggestConversationActions(request);
+ }
+
+ /**
+ * Returns the {@link ConversationAction} result, with a non-null extras.
+ *
+ * <p>Whenever the RemoteAction is non-null, you can expect its corresponding intent with a
+ * non-null component name is in the extras.
+ */
+ private ConversationActions createConversationActionResult(
+ ConversationActions.Request request,
+ ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions) {
+ Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
+ List<ConversationAction> conversationActions = new ArrayList<>();
+ for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : nativeSuggestions) {
+ String actionType = nativeSuggestion.getActionType();
+ if (!expectedTypes.contains(actionType)) {
+ continue;
+ }
+ LabeledIntent.Result labeledIntentResult =
+ ActionsSuggestionsHelper.createLabeledIntentResult(
+ context, templateIntentFactory, nativeSuggestion);
+ RemoteAction remoteAction = null;
+ Bundle extras = new Bundle();
+ if (labeledIntentResult != null) {
+ remoteAction = labeledIntentResult.remoteAction;
+ ExtrasUtils.putActionIntent(extras, labeledIntentResult.resolvedIntent);
+ }
+ ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData());
+ ExtrasUtils.putEntitiesExtras(
+ extras, TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData()));
+ conversationActions.add(
+ new ConversationAction.Builder(actionType)
+ .setConfidenceScore(nativeSuggestion.getScore())
+ .setTextReply(nativeSuggestion.getResponseText())
+ .setAction(remoteAction)
+ .setExtras(extras)
+ .build());
+ }
+ conversationActions = ActionsSuggestionsHelper.removeActionsWithDuplicates(conversationActions);
+ if (request.getMaxSuggestions() >= 0
+ && conversationActions.size() > request.getMaxSuggestions()) {
+ conversationActions = conversationActions.subList(0, request.getMaxSuggestions());
+ }
+ synchronized (lock) {
+ String resultId =
+ ActionsSuggestionsHelper.createResultId(
+ context,
+ request.getConversation(),
+ actionModelInUse.getVersion(),
+ actionModelInUse.getSupportedLocales());
+ return new ConversationActions(conversationActions, resultId);
+ }
+ }
+
+ private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) {
+ List<String> defaultActionTypes =
+ request.getHints().contains(ConversationActions.Request.HINT_FOR_NOTIFICATION)
+ ? settings.getNotificationConversationActionTypes()
+ : settings.getInAppConversationActionTypes();
+ return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
+ }
+
+ private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException {
+ synchronized (lock) {
+ localeList = localeList == null ? LocaleList.getDefault() : localeList;
+ final ModelFileManager.ModelFile bestModel =
+ annotatorModelFileManager.findBestModelFile(localeList);
+ if (bestModel == null) {
+ throw new FileNotFoundException("No annotator model for " + localeList.toLanguageTags());
+ }
+ if (annotatorImpl == null || !Objects.equals(annotatorModelInUse, bestModel)) {
+ TcLog.d(TAG, "Loading " + bestModel);
+ final ParcelFileDescriptor pfd =
+ ParcelFileDescriptor.open(
+ new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
+ try {
+ if (pfd != null) {
+ // The current annotator model may be still used by another thread / model.
+ // Do not call close() here, and let the GC to clean it up when no one else
+ // is using it.
+ annotatorImpl = new AnnotatorModel(pfd.getFd());
+ annotatorModelInUse = bestModel;
+ }
+ } finally {
+ maybeCloseAndLogError(pfd);
+ }
+ }
+ return annotatorImpl;
+ }
+ }
+
+ private LangIdModel getLangIdImpl() throws FileNotFoundException {
+ synchronized (lock) {
+ final ModelFileManager.ModelFile bestModel = langIdModelFileManager.findBestModelFile(null);
+ if (bestModel == null) {
+ throw new FileNotFoundException("No LangID model is found");
+ }
+ if (langIdImpl == null || !Objects.equals(langIdModelInUse, bestModel)) {
+ TcLog.d(TAG, "Loading " + bestModel);
+ final ParcelFileDescriptor pfd =
+ ParcelFileDescriptor.open(
+ new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
+ try {
+ if (pfd != null) {
+ langIdImpl = new LangIdModel(pfd.getFd());
+ langIdModelInUse = bestModel;
+ }
+ } finally {
+ maybeCloseAndLogError(pfd);
+ }
+ }
+ return langIdImpl;
+ }
+ }
+
+ @Nullable
+ private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException {
+ synchronized (lock) {
+ // TODO: Use LangID to determine the locale we should use here?
+ final ModelFileManager.ModelFile bestModel =
+ actionsModelFileManager.findBestModelFile(LocaleList.getDefault());
+ if (bestModel == null) {
+ return null;
+ }
+ if (actionsImpl == null || !Objects.equals(actionModelInUse, bestModel)) {
+ TcLog.d(TAG, "Loading " + bestModel);
+ final ParcelFileDescriptor pfd =
+ ParcelFileDescriptor.open(
+ new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
+ try {
+ if (pfd == null) {
+ TcLog.d(TAG, "Failed to read the model file: " + bestModel.getPath());
return null;
+ }
+ ActionsModelParams params = actionsModelParamsSupplier.get();
+ actionsImpl =
+ new ActionsSuggestionsModel(
+ pfd.getFd(), params.getSerializedPreconditions(bestModel));
+ actionModelInUse = bestModel;
+ } finally {
+ maybeCloseAndLogError(pfd);
}
- try {
- final float threshold = getLangIdThreshold();
- if (threshold < 0 || threshold > 1) {
- TcLog.w(TAG, "[detectForeignLanguage] unexpected threshold is found: " + threshold);
- return Pair.create(null, null);
- }
+ }
+ return actionsImpl;
+ }
+ }
- final EntityConfidence languageScores = detectLanguages(context, start, end);
- if (languageScores.getEntities().isEmpty()) {
- return Pair.create(null, null);
- }
+ private String createId(String text, int start, int end) {
+ synchronized (lock) {
+ return ResultIdUtils.createId(
+ context,
+ text,
+ start,
+ end,
+ annotatorModelInUse.getVersion(),
+ annotatorModelInUse.getSupportedLocales());
+ }
+ }
- final Bundle textLanguagesBundle = new Bundle();
- ExtrasUtils.putTopLanguageScores(textLanguagesBundle, languageScores);
+ private static String concatenateLocales(@Nullable LocaleList locales) {
+ return (locales == null) ? "" : locales.toLanguageTags();
+ }
- final String language = languageScores.getEntities().get(0);
- final float score = languageScores.getConfidenceScore(language);
- if (score < threshold) {
- return Pair.create(textLanguagesBundle, null);
- }
+ private TextClassification createClassificationResult(
+ AnnotatorModel.ClassificationResult[] classifications,
+ String text,
+ int start,
+ int end,
+ @Nullable Instant referenceTime) {
+ final String classifiedText = text.substring(start, end);
+ final TextClassification.Builder builder =
+ new TextClassification.Builder().setText(classifiedText);
- TcLog.v(TAG, String.format(Locale.US, "Language detected: <%s:%.2f>", language, score));
- if (mSettings.isUserLanguageProfileEnabled()) {
- if (!mLanguageProfileAnalyzer.shouldShowTranslation(language)) {
- return Pair.create(textLanguagesBundle, null);
- }
- } else {
- final Locale detected = new Locale(language);
- final LocaleList deviceLocales = LocaleList.getDefault();
- final int size = deviceLocales.size();
- for (int i = 0; i < size; i++) {
- if (deviceLocales.get(i).getLanguage().equals(detected.getLanguage())) {
- return Pair.create(textLanguagesBundle, null);
- }
- }
- }
- final Bundle foreignLanguageBundle =
- ExtrasUtils.createForeignLanguageExtra(
- language, score, getLangIdImpl().getVersion());
- return Pair.create(textLanguagesBundle, foreignLanguageBundle);
- } catch (Throwable t) {
- TcLog.e(TAG, "Error generating language bundles.", t);
- }
+ final int typeCount = classifications.length;
+ AnnotatorModel.ClassificationResult highestScoringResult =
+ typeCount > 0 ? classifications[0] : null;
+ for (int i = 0; i < typeCount; i++) {
+ builder.setEntityType(classifications[i].getCollection(), classifications[i].getScore());
+ if (classifications[i].getScore() > highestScoringResult.getScore()) {
+ highestScoringResult = classifications[i];
+ }
+ }
+ final Pair<Bundle, Bundle> languagesBundles = generateLanguageBundles(text, start, end);
+ final Bundle textLanguagesBundle = languagesBundles.first;
+ final Bundle foreignLanguageBundle = languagesBundles.second;
+
+ boolean isPrimaryAction = true;
+ final List<LabeledIntent> labeledIntents =
+ classificationIntentFactory.create(
+ context,
+ classifiedText,
+ foreignLanguageBundle != null,
+ referenceTime,
+ highestScoringResult);
+ final LabeledIntent.TitleChooser titleChooser =
+ (labeledIntent, resolveInfo) -> labeledIntent.titleWithoutEntity;
+
+ ArrayList<Intent> actionIntents = new ArrayList<>();
+ for (LabeledIntent labeledIntent : labeledIntents) {
+ final LabeledIntent.Result result =
+ labeledIntent.resolve(context, titleChooser, textLanguagesBundle);
+ if (result == null) {
+ continue;
+ }
+
+ final Intent intent = result.resolvedIntent;
+ final RemoteAction action = result.remoteAction;
+ if (isPrimaryAction) {
+ // For O backwards compatibility, the first RemoteAction is also written to the
+ // legacy API fields.
+ builder.setIcon(action.getIcon().loadDrawable(context));
+ builder.setLabel(action.getTitle().toString());
+ builder.setIntent(intent);
+ builder.setOnClickListener(
+ createIntentOnClickListener(
+ createPendingIntent(context, intent, labeledIntent.requestCode)));
+ isPrimaryAction = false;
+ }
+ builder.addAction(action);
+ actionIntents.add(intent);
+ }
+ Bundle extras = new Bundle();
+ ExtrasUtils.putForeignLanguageExtra(extras, foreignLanguageBundle);
+ if (actionIntents.stream().anyMatch(Objects::nonNull)) {
+ ExtrasUtils.putActionsIntents(extras, actionIntents);
+ }
+ ExtrasUtils.putEntities(extras, classifications);
+ builder.setExtras(extras);
+ return builder.setId(createId(text, start, end)).build();
+ }
+
+ private static OnClickListener createIntentOnClickListener(final PendingIntent intent) {
+ Preconditions.checkNotNull(intent);
+ return v -> {
+ try {
+ intent.send();
+ } catch (PendingIntent.CanceledException e) {
+ TcLog.e(TAG, "Error sending PendingIntent", e);
+ }
+ };
+ }
+
+ /**
+ * Returns a bundle pair with language detection information for extras.
+ *
+ * <p>Pair.first = textLanguagesBundle - A bundle containing information about all detected
+ * languages in the text. May be null if language detection fails or is disabled. This is
+ * typically expected to be added to a textClassifier generated remote action intent. See {@link
+ * ExtrasUtils#putTextLanguagesExtra(Bundle, Bundle)}. See {@link
+ * ExtrasUtils#getTopLanguage(Intent)}.
+ *
+ * <p>Pair.second = foreignLanguageBundle - A bundle with the language and confidence score if the
+ * system finds the text to be in a foreign language. Otherwise is null. See {@link
+ * TextClassification.Builder#setForeignLanguageExtra(Bundle)}.
+ *
+ * @param context the context of the text to detect languages for
+ * @param start the start index of the text
+ * @param end the end index of the text
+ */
+ // TODO: Revisit this algorithm.
+ // TODO: Consider making this public API.
+ private Pair<Bundle, Bundle> generateLanguageBundles(String context, int start, int end) {
+ if (!settings.isTranslateInClassificationEnabled()) {
+ return null;
+ }
+ try {
+ final float threshold = getLangIdThreshold();
+ if (threshold < 0 || threshold > 1) {
+ TcLog.w(TAG, "[detectForeignLanguage] unexpected threshold is found: " + threshold);
return Pair.create(null, null);
+ }
+
+ final EntityConfidence languageScores = detectLanguages(context, start, end);
+ if (languageScores.getEntities().isEmpty()) {
+ return Pair.create(null, null);
+ }
+
+ final Bundle textLanguagesBundle = new Bundle();
+ ExtrasUtils.putTopLanguageScores(textLanguagesBundle, languageScores);
+
+ final String language = languageScores.getEntities().get(0);
+ final float score = languageScores.getConfidenceScore(language);
+ if (score < threshold) {
+ return Pair.create(textLanguagesBundle, null);
+ }
+
+ TcLog.v(TAG, String.format(Locale.US, "Language detected: <%s:%.2f>", language, score));
+ if (settings.isUserLanguageProfileEnabled()) {
+ if (!languageProfileAnalyzer.shouldShowTranslation(language)) {
+ return Pair.create(textLanguagesBundle, null);
+ }
+ } else {
+ final Locale detected = new Locale(language);
+ final LocaleList deviceLocales = LocaleList.getDefault();
+ final int size = deviceLocales.size();
+ for (int i = 0; i < size; i++) {
+ if (deviceLocales.get(i).getLanguage().equals(detected.getLanguage())) {
+ return Pair.create(textLanguagesBundle, null);
+ }
+ }
+ }
+ final Bundle foreignLanguageBundle =
+ ExtrasUtils.createForeignLanguageExtra(language, score, getLangIdImpl().getVersion());
+ return Pair.create(textLanguagesBundle, foreignLanguageBundle);
+ } catch (Throwable t) {
+ TcLog.e(TAG, "Error generating language bundles.", t);
+ }
+ return Pair.create(null, null);
+ }
+
+ /**
+ * Detect the language of a piece of text by taking surrounding text into consideration.
+ *
+ * @param text text providing context for the text for which its language is to be detected
+ * @param start the start index of the text to detect its language
+ * @param end the end index of the text to detect its language
+ */
+ // TODO: Revisit this algorithm.
+ private EntityConfidence detectLanguages(String text, int start, int end) {
+ Preconditions.checkArgument(start >= 0);
+ Preconditions.checkArgument(end <= text.length());
+ Preconditions.checkArgument(start <= end);
+
+ final float[] langIdContextSettings = settings.getLangIdContextSettings();
+ // The minimum size of text to prefer for detection.
+ final int minimumTextSize = (int) langIdContextSettings[0];
+ // For reducing the score when text is less than the preferred size.
+ final float penalizeRatio = langIdContextSettings[1];
+ // Original detection score to surrounding text detection score ratios.
+ final float subjectTextScoreRatio = langIdContextSettings[2];
+ final float moreTextScoreRatio = 1f - subjectTextScoreRatio;
+ TcLog.v(
+ TAG,
+ String.format(
+ Locale.US,
+ "LangIdContextSettings: "
+ + "minimumTextSize=%d, penalizeRatio=%.2f, "
+ + "subjectTextScoreRatio=%.2f, moreTextScoreRatio=%.2f",
+ minimumTextSize,
+ penalizeRatio,
+ subjectTextScoreRatio,
+ moreTextScoreRatio));
+
+ if (end - start < minimumTextSize && penalizeRatio <= 0) {
+ return EntityConfidence.EMPTY;
}
- /**
- * Detect the language of a piece of text by taking surrounding text into consideration.
- *
- * @param text text providing context for the text for which its language is to be detected
- * @param start the start index of the text to detect its language
- * @param end the end index of the text to detect its language
- */
- // TODO: Revisit this algorithm.
- private EntityConfidence detectLanguages(String text, int start, int end) {
- Preconditions.checkArgument(start >= 0);
- Preconditions.checkArgument(end <= text.length());
- Preconditions.checkArgument(start <= end);
+ final String subject = text.substring(start, end);
+ final EntityConfidence scores = detectLanguages(subject, /* threshold= */ 0f);
- final float[] langIdContextSettings = mSettings.getLangIdContextSettings();
- // The minimum size of text to prefer for detection.
- final int minimumTextSize = (int) langIdContextSettings[0];
- // For reducing the score when text is less than the preferred size.
- final float penalizeRatio = langIdContextSettings[1];
- // Original detection score to surrounding text detection score ratios.
- final float subjectTextScoreRatio = langIdContextSettings[2];
- final float moreTextScoreRatio = 1f - subjectTextScoreRatio;
- TcLog.v(
- TAG,
- String.format(
- Locale.US,
- "LangIdContextSettings: "
- + "minimumTextSize=%d, penalizeRatio=%.2f, "
- + "subjectTextScoreRatio=%.2f, moreTextScoreRatio=%.2f",
- minimumTextSize,
- penalizeRatio,
- subjectTextScoreRatio,
- moreTextScoreRatio));
-
- if (end - start < minimumTextSize && penalizeRatio <= 0) {
- return EntityConfidence.EMPTY;
- }
-
- final String subject = text.substring(start, end);
- final EntityConfidence scores = detectLanguages(subject, /* threshold= */ 0f);
-
- if (subject.length() >= minimumTextSize
- || subject.length() == text.length()
- || subjectTextScoreRatio * penalizeRatio >= 1) {
- return scores;
- }
-
- final EntityConfidence moreTextScores;
- if (moreTextScoreRatio >= 0) {
- // Attempt to grow the detection text to be at least minimumTextSize long.
- final String moreText = StringUtils.getSubString(text, start, end, minimumTextSize);
- moreTextScores = detectLanguages(moreText, /* threshold= */ 0f);
- } else {
- moreTextScores = EntityConfidence.EMPTY;
- }
-
- // Combine the original detection scores with the those returned after including more text.
- final Map<String, Float> newScores = new ArrayMap<>();
- final Set<String> languages = new ArraySet<>();
- languages.addAll(scores.getEntities());
- languages.addAll(moreTextScores.getEntities());
- for (String language : languages) {
- final float score =
- (subjectTextScoreRatio * scores.getConfidenceScore(language)
- + moreTextScoreRatio
- * moreTextScores.getConfidenceScore(language))
- * penalizeRatio;
- newScores.put(language, score);
- }
- return new EntityConfidence(newScores);
+ if (subject.length() >= minimumTextSize
+ || subject.length() == text.length()
+ || subjectTextScoreRatio * penalizeRatio >= 1) {
+ return scores;
}
- /**
- * Detects languages for the specified text. Only returns languages with score that is higher
- * than or equal to the specified threshold.
- */
- private EntityConfidence detectLanguages(CharSequence text, float threshold) {
- final LangIdModel langId;
- try {
- langId = getLangIdImpl();
- } catch (FileNotFoundException e) {
- TcLog.e(TAG, "detectLanguages: Failed to call getLangIdImpl ", e);
- return EntityConfidence.EMPTY;
- }
- final LangIdModel.LanguageResult[] langResults = langId.detectLanguages(text.toString());
- final Map<String, Float> languagesMap = new ArrayMap<>();
- for (LangIdModel.LanguageResult langResult : langResults) {
- if (langResult.getScore() >= threshold) {
- languagesMap.put(langResult.getLanguage(), langResult.getScore());
- }
- }
- return new EntityConfidence(languagesMap);
+ final EntityConfidence moreTextScores;
+ if (moreTextScoreRatio >= 0) {
+ // Attempt to grow the detection text to be at least minimumTextSize long.
+ final String moreText = StringUtils.getSubString(text, start, end, minimumTextSize);
+ moreTextScores = detectLanguages(moreText, /* threshold= */ 0f);
+ } else {
+ moreTextScores = EntityConfidence.EMPTY;
}
- private float getLangIdThreshold() {
- try {
- return mSettings.getLangIdThresholdOverride() >= 0
- ? mSettings.getLangIdThresholdOverride()
- : getLangIdImpl().getLangIdThreshold();
- } catch (FileNotFoundException e) {
- final float defaultThreshold = 0.5f;
- TcLog.v(TAG, "Using default foreign language threshold: " + defaultThreshold);
- return defaultThreshold;
- }
+ // Combine the original detection scores with the those returned after including more text.
+ final Map<String, Float> newScores = new ArrayMap<>();
+ final Set<String> languages = new ArraySet<>();
+ languages.addAll(scores.getEntities());
+ languages.addAll(moreTextScores.getEntities());
+ for (String language : languages) {
+ final float score =
+ (subjectTextScoreRatio * scores.getConfidenceScore(language)
+ + moreTextScoreRatio * moreTextScores.getConfidenceScore(language))
+ * penalizeRatio;
+ newScores.put(language, score);
+ }
+ return new EntityConfidence(newScores);
+ }
+
+ /**
+ * Detects languages for the specified text. Only returns languages with score that is higher than
+ * or equal to the specified threshold.
+ */
+ private EntityConfidence detectLanguages(CharSequence text, float threshold) {
+ final LangIdModel langId;
+ try {
+ langId = getLangIdImpl();
+ } catch (FileNotFoundException e) {
+ TcLog.e(TAG, "detectLanguages: Failed to call getLangIdImpl ", e);
+ return EntityConfidence.EMPTY;
+ }
+ final LangIdModel.LanguageResult[] langResults = langId.detectLanguages(text.toString());
+ final Map<String, Float> languagesMap = new ArrayMap<>();
+ for (LangIdModel.LanguageResult langResult : langResults) {
+ if (langResult.getScore() >= threshold) {
+ languagesMap.put(langResult.getLanguage(), langResult.getScore());
+ }
+ }
+ return new EntityConfidence(languagesMap);
+ }
+
+ private float getLangIdThreshold() {
+ try {
+ return settings.getLangIdThresholdOverride() >= 0
+ ? settings.getLangIdThresholdOverride()
+ : getLangIdImpl().getLangIdThreshold();
+ } catch (FileNotFoundException e) {
+ final float defaultThreshold = 0.5f;
+ TcLog.v(TAG, "Using default foreign language threshold: " + defaultThreshold);
+ return defaultThreshold;
+ }
+ }
+
+ void dump(IndentingPrintWriter printWriter) {
+ synchronized (lock) {
+ printWriter.println("TextClassifierImpl:");
+ printWriter.increaseIndent();
+ printWriter.println("Annotator model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFileManager.ModelFile modelFile : annotatorModelFileManager.listModelFiles()) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.println("LangID model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFileManager.ModelFile modelFile : langIdModelFileManager.listModelFiles()) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.println("Actions model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFileManager.ModelFile modelFile : actionsModelFileManager.listModelFiles()) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.printPair("mFallback", fallback);
+ printWriter.decreaseIndent();
+ printWriter.println();
+ settings.dump(printWriter);
+ if (settings.isUserLanguageProfileEnabled()) {
+ printWriter.println();
+ languageProfileUpdater.dump(printWriter);
+ printWriter.println();
+ languageProfileAnalyzer.dump(printWriter);
+ }
+ }
+ }
+
+ /** Returns the locales string for the current resources configuration. */
+ private String getResourceLocalesString() {
+ try {
+ return context.getResources().getConfiguration().getLocales().toLanguageTags();
+ } catch (NullPointerException e) {
+
+ // NPE is unexpected. Erring on the side of caution.
+ return LocaleList.getDefault().toLanguageTags();
+ }
+ }
+
+ /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
+ private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
+ if (fd == null) {
+ return;
}
- void dump(IndentingPrintWriter printWriter) {
- synchronized (mLock) {
- printWriter.println("TextClassifierImpl:");
- printWriter.increaseIndent();
- printWriter.println("Annotator model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile :
- mAnnotatorModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
- printWriter.println("LangID model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile : mLangIdModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
- printWriter.println("Actions model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile : mActionsModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
- printWriter.printPair("mFallback", mFallback);
- printWriter.decreaseIndent();
- printWriter.println();
- mSettings.dump(printWriter);
- if (mSettings.isUserLanguageProfileEnabled()) {
- printWriter.println();
- mLanguageProfileUpdater.dump(printWriter);
- printWriter.println();
- mLanguageProfileAnalyzer.dump(printWriter);
- }
- }
+ try {
+ fd.close();
+ } catch (IOException e) {
+ TcLog.e(TAG, "Error closing file.", e);
}
+ }
- /** Returns the locales string for the current resources configuration. */
- private String getResourceLocalesString() {
- try {
- return mContext.getResources().getConfiguration().getLocales().toLanguageTags();
- } catch (NullPointerException e) {
-
- // NPE is unexpected. Erring on the side of caution.
- return LocaleList.getDefault().toLanguageTags();
- }
+ private static void checkMainThread() {
+ if (Looper.myLooper() == Looper.getMainLooper()) {
+ TcLog.e(TAG, "TextClassifier called on main thread", new Exception());
}
+ }
- /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
- private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
- if (fd == null) {
- return;
- }
-
- try {
- fd.close();
- } catch (IOException e) {
- TcLog.e(TAG, "Error closing file.", e);
- }
- }
-
- private static void checkMainThread() {
- if (Looper.myLooper() == Looper.getMainLooper()) {
- TcLog.e(TAG, "TextClassifier called on main thread", new Exception());
- }
- }
-
- private static PendingIntent createPendingIntent(
- @NonNull final Context context, @NonNull final Intent intent, int requestCode) {
- return PendingIntent.getActivity(
- context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
- }
+ private static PendingIntent createPendingIntent(
+ final Context context, final Intent intent, int requestCode) {
+ return PendingIntent.getActivity(
+ context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ }
}
diff --git a/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java b/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java
index 180bc96..8704644 100644
--- a/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java
+++ b/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,44 +13,41 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.intent;
import android.content.Context;
import android.content.Intent;
-
-import androidx.annotation.Nullable;
-
import com.android.textclassifier.R;
-
import com.google.android.textclassifier.AnnotatorModel;
-
import java.time.Instant;
import java.util.List;
+import javax.annotation.Nullable;
-/** @hide */
+/** Generates intents from classification results. */
public interface ClassificationIntentFactory {
- /** Return a list of LabeledIntent from the classification result. */
- List<LabeledIntent> create(
- Context context,
- String text,
- boolean foreignText,
- @Nullable Instant referenceTime,
- @Nullable AnnotatorModel.ClassificationResult classification);
+ /** Return a list of LabeledIntent from the classification result. */
+ List<LabeledIntent> create(
+ Context context,
+ String text,
+ boolean foreignText,
+ @Nullable Instant referenceTime,
+ @Nullable AnnotatorModel.ClassificationResult classification);
- /** Inserts translate action to the list if it is a foreign text. */
- static void insertTranslateAction(List<LabeledIntent> actions, Context context, String text) {
- actions.add(
- new LabeledIntent(
- context.getString(R.string.translate),
- /* titleWithEntity */ null,
- context.getString(R.string.translate_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_TRANSLATE)
- // TODO: Probably better to introduce a "translate" scheme instead
- // of
- // using EXTRA_TEXT.
- .putExtra(Intent.EXTRA_TEXT, text),
- text.hashCode()));
- }
+ /** Inserts translate action to the list if it is a foreign text. */
+ static void insertTranslateAction(List<LabeledIntent> actions, Context context, String text) {
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.translate),
+ /* titleWithEntity */ null,
+ context.getString(R.string.translate_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_TRANSLATE)
+ // TODO: Probably better to introduce a "translate" scheme instead
+ // of
+ // using EXTRA_TEXT.
+ .putExtra(Intent.EXTRA_TEXT, text),
+ text.hashCode()));
+ }
}
diff --git a/java/src/com/android/textclassifier/intent/LabeledIntent.java b/java/src/com/android/textclassifier/intent/LabeledIntent.java
index e51aa82..e678291 100644
--- a/java/src/com/android/textclassifier/intent/LabeledIntent.java
+++ b/java/src/com/android/textclassifier/intent/LabeledIntent.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.intent;
import android.app.PendingIntent;
@@ -26,195 +27,183 @@
import android.os.Bundle;
import android.text.TextUtils;
import android.view.textclassifier.TextClassifier;
-
-import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-import androidx.core.util.Preconditions;
-
import com.android.textclassifier.ExtrasUtils;
import com.android.textclassifier.R;
import com.android.textclassifier.TcLog;
+import com.google.common.base.Preconditions;
+import javax.annotation.Nullable;
/**
* Helper class to store the information from which RemoteActions are built.
*
* @hide
*/
-@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public final class LabeledIntent {
- private static final String TAG = "LabeledIntent";
- public static final int DEFAULT_REQUEST_CODE = 0;
- private static final TitleChooser DEFAULT_TITLE_CHOOSER =
- (labeledIntent, resolveInfo) -> {
- if (!TextUtils.isEmpty(labeledIntent.titleWithEntity)) {
- return labeledIntent.titleWithEntity;
- }
- return labeledIntent.titleWithoutEntity;
- };
-
- @Nullable public final String titleWithoutEntity;
- @Nullable public final String titleWithEntity;
- public final String description;
- @Nullable public final String descriptionWithAppName;
- // Do not update this intent.
- public final Intent intent;
- public final int requestCode;
-
- /**
- * Initializes a LabeledIntent.
- *
- * <p>NOTE: {@code requestCode} is required to not be {@link #DEFAULT_REQUEST_CODE} if
- * distinguishing info (e.g. the classified text) is represented in intent extras only. In such
- * circumstances, the request code should represent the distinguishing info (e.g. by generating
- * a hashcode) so that the generated PendingIntent is (somewhat) unique. To be correct, the
- * PendingIntent should be definitely unique but we try a best effort approach that avoids
- * spamming the system with PendingIntents.
- */
- // TODO: Fix the issue mentioned above so the behaviour is correct.
- public LabeledIntent(
- @Nullable String titleWithoutEntity,
- @Nullable String titleWithEntity,
- String description,
- @Nullable String descriptionWithAppName,
- Intent intent,
- int requestCode) {
- if (TextUtils.isEmpty(titleWithEntity) && TextUtils.isEmpty(titleWithoutEntity)) {
- throw new IllegalArgumentException(
- "titleWithEntity and titleWithoutEntity should not be both null");
+ private static final String TAG = "LabeledIntent";
+ public static final int DEFAULT_REQUEST_CODE = 0;
+ private static final TitleChooser DEFAULT_TITLE_CHOOSER =
+ (labeledIntent, resolveInfo) -> {
+ if (!TextUtils.isEmpty(labeledIntent.titleWithEntity)) {
+ return labeledIntent.titleWithEntity;
}
- this.titleWithoutEntity = titleWithoutEntity;
- this.titleWithEntity = titleWithEntity;
- this.description = Preconditions.checkNotNull(description);
- this.descriptionWithAppName = descriptionWithAppName;
- this.intent = Preconditions.checkNotNull(intent);
- this.requestCode = requestCode;
- }
+ return labeledIntent.titleWithoutEntity;
+ };
+ @Nullable public final String titleWithoutEntity;
+ @Nullable public final String titleWithEntity;
+ public final String description;
+ @Nullable public final String descriptionWithAppName;
+ // Do not update this intent.
+ public final Intent intent;
+ public final int requestCode;
+
+ /**
+ * Initializes a LabeledIntent.
+ *
+ * <p>NOTE: {@code requestCode} is required to not be {@link #DEFAULT_REQUEST_CODE} if
+ * distinguishing info (e.g. the classified text) is represented in intent extras only. In such
+ * circumstances, the request code should represent the distinguishing info (e.g. by generating a
+ * hashcode) so that the generated PendingIntent is (somewhat) unique. To be correct, the
+ * PendingIntent should be definitely unique but we try a best effort approach that avoids
+ * spamming the system with PendingIntents.
+ */
+ // TODO: Fix the issue mentioned above so the behaviour is correct.
+ public LabeledIntent(
+ @Nullable String titleWithoutEntity,
+ @Nullable String titleWithEntity,
+ String description,
+ @Nullable String descriptionWithAppName,
+ Intent intent,
+ int requestCode) {
+ if (TextUtils.isEmpty(titleWithEntity) && TextUtils.isEmpty(titleWithoutEntity)) {
+ throw new IllegalArgumentException(
+ "titleWithEntity and titleWithoutEntity should not be both null");
+ }
+ this.titleWithoutEntity = titleWithoutEntity;
+ this.titleWithEntity = titleWithEntity;
+ this.description = Preconditions.checkNotNull(description);
+ this.descriptionWithAppName = descriptionWithAppName;
+ this.intent = Preconditions.checkNotNull(intent);
+ this.requestCode = requestCode;
+ }
+
+ /**
+ * Return the resolved result.
+ *
+ * @param context the context to resolve the result's intent and action
+ * @param titleChooser for choosing an action title
+ * @param textLanguagesBundle containing language detection information
+ */
+ @Nullable
+ public Result resolve(
+ Context context, @Nullable TitleChooser titleChooser, @Nullable Bundle textLanguagesBundle) {
+ final PackageManager pm = context.getPackageManager();
+ final ResolveInfo resolveInfo = pm.resolveActivity(intent, 0);
+
+ if (resolveInfo == null || resolveInfo.activityInfo == null) {
+ TcLog.w(TAG, "resolveInfo or activityInfo is null");
+ return null;
+ }
+ final String packageName = resolveInfo.activityInfo.packageName;
+ final String className = resolveInfo.activityInfo.name;
+ if (packageName == null || className == null) {
+ TcLog.w(TAG, "packageName or className is null");
+ return null;
+ }
+ Intent resolvedIntent = new Intent(intent);
+ resolvedIntent.putExtra(
+ TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER, getFromTextClassifierExtra(textLanguagesBundle));
+ boolean shouldShowIcon = false;
+ Icon icon = null;
+ if (!"android".equals(packageName)) {
+ // We only set the component name when the package name is not resolved to "android"
+ // to workaround a bug that explicit intent with component name == ResolverActivity
+ // can't be launched on keyguard.
+ resolvedIntent.setComponent(new ComponentName(packageName, className));
+ if (resolveInfo.activityInfo.getIconResource() != 0) {
+ icon = Icon.createWithResource(packageName, resolveInfo.activityInfo.getIconResource());
+ shouldShowIcon = true;
+ }
+ }
+ if (icon == null) {
+ // RemoteAction requires that there be an icon.
+ icon = Icon.createWithResource(context, R.drawable.app_icon);
+ }
+ final PendingIntent pendingIntent = createPendingIntent(context, resolvedIntent, requestCode);
+ titleChooser = titleChooser == null ? DEFAULT_TITLE_CHOOSER : titleChooser;
+ CharSequence title = titleChooser.chooseTitle(this, resolveInfo);
+ if (TextUtils.isEmpty(title)) {
+ TcLog.w(TAG, "Custom titleChooser return null, fallback to the default titleChooser");
+ title = DEFAULT_TITLE_CHOOSER.chooseTitle(this, resolveInfo);
+ }
+ final RemoteAction action =
+ new RemoteAction(icon, title, resolveDescription(resolveInfo, pm), pendingIntent);
+ action.setShouldShowIcon(shouldShowIcon);
+ return new Result(resolvedIntent, action);
+ }
+
+ private String resolveDescription(ResolveInfo resolveInfo, PackageManager packageManager) {
+ if (!TextUtils.isEmpty(descriptionWithAppName)) {
+ // Example string format of descriptionWithAppName: "Use %1$s to open map".
+ String applicationName = getApplicationName(resolveInfo, packageManager);
+ if (!TextUtils.isEmpty(applicationName)) {
+ return String.format(descriptionWithAppName, applicationName);
+ }
+ }
+ return description;
+ }
+
+ @Nullable
+ private String getApplicationName(ResolveInfo resolveInfo, PackageManager packageManager) {
+ if (resolveInfo.activityInfo == null) {
+ return null;
+ }
+ if ("android".equals(resolveInfo.activityInfo.packageName)) {
+ return null;
+ }
+ if (resolveInfo.activityInfo.applicationInfo == null) {
+ return null;
+ }
+ return (String) packageManager.getApplicationLabel(resolveInfo.activityInfo.applicationInfo);
+ }
+
+ private Bundle getFromTextClassifierExtra(@Nullable Bundle textLanguagesBundle) {
+ if (textLanguagesBundle != null) {
+ final Bundle bundle = new Bundle();
+ ExtrasUtils.putTextLanguagesExtra(bundle, textLanguagesBundle);
+ return bundle;
+ } else {
+ return Bundle.EMPTY;
+ }
+ }
+
+ private static PendingIntent createPendingIntent(
+ final Context context, final Intent intent, int requestCode) {
+ return PendingIntent.getActivity(
+ context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ }
+
+ /** Data class that holds the result. */
+ public static final class Result {
+ public final Intent resolvedIntent;
+ public final RemoteAction remoteAction;
+
+ public Result(Intent resolvedIntent, RemoteAction remoteAction) {
+ this.resolvedIntent = Preconditions.checkNotNull(resolvedIntent);
+ this.remoteAction = Preconditions.checkNotNull(remoteAction);
+ }
+ }
+
+ /**
+ * An object to choose a title from resolved info. If {@code null} is returned, {@link
+ * #titleWithEntity} will be used if it exists, {@link #titleWithoutEntity} otherwise.
+ */
+ public interface TitleChooser {
/**
- * Return the resolved result.
- *
- * @param context the context to resolve the result's intent and action
- * @param titleChooser for choosing an action title
- * @param textLanguagesBundle containing language detection information
+ * Picks a title from a {@link LabeledIntent} by looking into resolved info. {@code resolveInfo}
+ * is guaranteed to have a non-null {@code activityInfo}.
*/
@Nullable
- public Result resolve(
- Context context,
- @Nullable TitleChooser titleChooser,
- @Nullable Bundle textLanguagesBundle) {
- final PackageManager pm = context.getPackageManager();
- final ResolveInfo resolveInfo = pm.resolveActivity(intent, 0);
-
- if (resolveInfo == null || resolveInfo.activityInfo == null) {
- TcLog.w(TAG, "resolveInfo or activityInfo is null");
- return null;
- }
- final String packageName = resolveInfo.activityInfo.packageName;
- final String className = resolveInfo.activityInfo.name;
- if (packageName == null || className == null) {
- TcLog.w(TAG, "packageName or className is null");
- return null;
- }
- Intent resolvedIntent = new Intent(intent);
- resolvedIntent.putExtra(
- TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER,
- getFromTextClassifierExtra(textLanguagesBundle));
- boolean shouldShowIcon = false;
- Icon icon = null;
- if (!"android".equals(packageName)) {
- // We only set the component name when the package name is not resolved to "android"
- // to workaround a bug that explicit intent with component name == ResolverActivity
- // can't be launched on keyguard.
- resolvedIntent.setComponent(new ComponentName(packageName, className));
- if (resolveInfo.activityInfo.getIconResource() != 0) {
- icon =
- Icon.createWithResource(
- packageName, resolveInfo.activityInfo.getIconResource());
- shouldShowIcon = true;
- }
- }
- if (icon == null) {
- // RemoteAction requires that there be an icon.
- icon = Icon.createWithResource(context, R.drawable.app_icon);
- }
- final PendingIntent pendingIntent =
- createPendingIntent(context, resolvedIntent, requestCode);
- titleChooser = titleChooser == null ? DEFAULT_TITLE_CHOOSER : titleChooser;
- CharSequence title = titleChooser.chooseTitle(this, resolveInfo);
- if (TextUtils.isEmpty(title)) {
- TcLog.w(TAG, "Custom titleChooser return null, fallback to the default titleChooser");
- title = DEFAULT_TITLE_CHOOSER.chooseTitle(this, resolveInfo);
- }
- final RemoteAction action =
- new RemoteAction(icon, title, resolveDescription(resolveInfo, pm), pendingIntent);
- action.setShouldShowIcon(shouldShowIcon);
- return new Result(resolvedIntent, action);
- }
-
- private String resolveDescription(ResolveInfo resolveInfo, PackageManager packageManager) {
- if (!TextUtils.isEmpty(descriptionWithAppName)) {
- // Example string format of descriptionWithAppName: "Use %1$s to open map".
- String applicationName = getApplicationName(resolveInfo, packageManager);
- if (!TextUtils.isEmpty(applicationName)) {
- return String.format(descriptionWithAppName, applicationName);
- }
- }
- return description;
- }
-
- @Nullable
- private String getApplicationName(ResolveInfo resolveInfo, PackageManager packageManager) {
- if (resolveInfo.activityInfo == null) {
- return null;
- }
- if ("android".equals(resolveInfo.activityInfo.packageName)) {
- return null;
- }
- if (resolveInfo.activityInfo.applicationInfo == null) {
- return null;
- }
- return (String)
- packageManager.getApplicationLabel(resolveInfo.activityInfo.applicationInfo);
- }
-
- private Bundle getFromTextClassifierExtra(@Nullable Bundle textLanguagesBundle) {
- if (textLanguagesBundle != null) {
- final Bundle bundle = new Bundle();
- ExtrasUtils.putTextLanguagesExtra(bundle, textLanguagesBundle);
- return bundle;
- } else {
- return Bundle.EMPTY;
- }
- }
-
- private static PendingIntent createPendingIntent(
- @NonNull final Context context, @NonNull final Intent intent, int requestCode) {
- return PendingIntent.getActivity(
- context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
- }
-
- /** Data class that holds the result. */
- public static final class Result {
- public final Intent resolvedIntent;
- public final RemoteAction remoteAction;
-
- public Result(Intent resolvedIntent, RemoteAction remoteAction) {
- this.resolvedIntent = Preconditions.checkNotNull(resolvedIntent);
- this.remoteAction = Preconditions.checkNotNull(remoteAction);
- }
- }
-
- /**
- * An object to choose a title from resolved info. If {@code null} is returned, {@link
- * #titleWithEntity} will be used if it exists, {@link #titleWithoutEntity} otherwise.
- */
- public interface TitleChooser {
- /**
- * Picks a title from a {@link LabeledIntent} by looking into resolved info. {@code
- * resolveInfo} is guaranteed to have a non-null {@code activityInfo}.
- */
- @Nullable
- CharSequence chooseTitle(LabeledIntent labeledIntent, ResolveInfo resolveInfo);
- }
+ CharSequence chooseTitle(LabeledIntent labeledIntent, ResolveInfo resolveInfo);
+ }
}
diff --git a/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java b/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java
index 4bac39a..c58df76 100644
--- a/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java
+++ b/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.intent;
import static java.time.temporal.ChronoUnit.MILLIS;
@@ -28,15 +29,9 @@
import android.provider.CalendarContract;
import android.provider.ContactsContract;
import android.view.textclassifier.TextClassifier;
-
-import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
-
import com.android.textclassifier.R;
import com.android.textclassifier.TcLog;
-
import com.google.android.textclassifier.AnnotatorModel;
-
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.time.Instant;
@@ -44,256 +39,239 @@
import java.util.List;
import java.util.Locale;
import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
-/**
- * Creates intents based on the classification type.
- *
- * @hide
- */
+/** Creates intents based on the classification type. */
// TODO: Consider to support {@code descriptionWithAppName}.
public final class LegacyClassificationIntentFactory implements ClassificationIntentFactory {
- private static final String TAG = "LegacyClassificationIntentFactory";
- private static final long MIN_EVENT_FUTURE_MILLIS = TimeUnit.MINUTES.toMillis(5);
- private static final long DEFAULT_EVENT_DURATION = TimeUnit.HOURS.toMillis(1);
+ private static final String TAG = "LegacyIntentFactory";
+ private static final long MIN_EVENT_FUTURE_MILLIS = TimeUnit.MINUTES.toMillis(5);
+ private static final long DEFAULT_EVENT_DURATION = TimeUnit.HOURS.toMillis(1);
- // Sync with TextClassifier.TYPE_DICTIONARY.
- private static final String TYPE_DICTIONARY = "dictionary";
+ // Sync with TextClassifier.TYPE_DICTIONARY.
+ private static final String TYPE_DICTIONARY = "dictionary";
- @NonNull
- @Override
- public List<LabeledIntent> create(
- Context context,
- String text,
- boolean foreignText,
- @Nullable Instant referenceTime,
- AnnotatorModel.ClassificationResult classification) {
- final String type =
- classification != null
- ? classification.getCollection().trim().toLowerCase(Locale.ENGLISH)
- : "";
- text = text.trim();
- final List<LabeledIntent> actions;
- switch (type) {
- case TextClassifier.TYPE_EMAIL:
- actions = createForEmail(context, text);
- break;
- case TextClassifier.TYPE_PHONE:
- actions = createForPhone(context, text);
- break;
- case TextClassifier.TYPE_ADDRESS:
- actions = createForAddress(context, text);
- break;
- case TextClassifier.TYPE_URL:
- actions = createForUrl(context, text);
- break;
- case TextClassifier.TYPE_DATE: // fall through
- case TextClassifier.TYPE_DATE_TIME:
- if (classification.getDatetimeResult() != null) {
- final Instant parsedTime =
- Instant.ofEpochMilli(classification.getDatetimeResult().getTimeMsUtc());
- actions = createForDatetime(context, type, referenceTime, parsedTime);
- } else {
- actions = new ArrayList<>();
- }
- break;
- case TextClassifier.TYPE_FLIGHT_NUMBER:
- actions = createForFlight(context, text);
- break;
- // case TextClassifier.TYPE_DICTIONARY:
- case TYPE_DICTIONARY:
- actions = createForDictionary(context, text);
- break;
- default:
- actions = new ArrayList<>();
- break;
+ @Override
+ public List<LabeledIntent> create(
+ Context context,
+ String text,
+ boolean foreignText,
+ @Nullable Instant referenceTime,
+ AnnotatorModel.ClassificationResult classification) {
+ final String type =
+ classification != null
+ ? classification.getCollection().trim().toLowerCase(Locale.ENGLISH)
+ : "";
+ text = text.trim();
+ final List<LabeledIntent> actions;
+ switch (type) {
+ case TextClassifier.TYPE_EMAIL:
+ actions = createForEmail(context, text);
+ break;
+ case TextClassifier.TYPE_PHONE:
+ actions = createForPhone(context, text);
+ break;
+ case TextClassifier.TYPE_ADDRESS:
+ actions = createForAddress(context, text);
+ break;
+ case TextClassifier.TYPE_URL:
+ actions = createForUrl(context, text);
+ break;
+ case TextClassifier.TYPE_DATE: // fall through
+ case TextClassifier.TYPE_DATE_TIME:
+ if (classification.getDatetimeResult() != null) {
+ final Instant parsedTime =
+ Instant.ofEpochMilli(classification.getDatetimeResult().getTimeMsUtc());
+ actions = createForDatetime(context, type, referenceTime, parsedTime);
+ } else {
+ actions = new ArrayList<>();
}
- if (foreignText) {
- ClassificationIntentFactory.insertTranslateAction(actions, context, text);
- }
- return actions;
+ break;
+ case TextClassifier.TYPE_FLIGHT_NUMBER:
+ actions = createForFlight(context, text);
+ break;
+ // case TextClassifier.TYPE_DICTIONARY:
+ case TYPE_DICTIONARY:
+ actions = createForDictionary(context, text);
+ break;
+ default:
+ actions = new ArrayList<>();
+ break;
}
+ if (foreignText) {
+ ClassificationIntentFactory.insertTranslateAction(actions, context, text);
+ }
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForEmail(Context context, String text) {
- final List<LabeledIntent> actions = new ArrayList<>();
- actions.add(
- new LabeledIntent(
- context.getString(R.string.email),
- /* titleWithEntity */ null,
- context.getString(R.string.email_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_SENDTO)
- .setData(Uri.parse(String.format("mailto:%s", text))),
- LabeledIntent.DEFAULT_REQUEST_CODE));
- actions.add(
- new LabeledIntent(
- context.getString(R.string.add_contact),
- /* titleWithEntity */ null,
- context.getString(R.string.add_contact_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_INSERT_OR_EDIT)
- .setType(ContactsContract.Contacts.CONTENT_ITEM_TYPE)
- .putExtra(ContactsContract.Intents.Insert.EMAIL, text),
- text.hashCode()));
- return actions;
- }
+ private static List<LabeledIntent> createForEmail(Context context, String text) {
+ final List<LabeledIntent> actions = new ArrayList<>();
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.email),
+ /* titleWithEntity */ null,
+ context.getString(R.string.email_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_SENDTO).setData(Uri.parse(String.format("mailto:%s", text))),
+ LabeledIntent.DEFAULT_REQUEST_CODE));
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.add_contact),
+ /* titleWithEntity */ null,
+ context.getString(R.string.add_contact_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_INSERT_OR_EDIT)
+ .setType(ContactsContract.Contacts.CONTENT_ITEM_TYPE)
+ .putExtra(ContactsContract.Intents.Insert.EMAIL, text),
+ text.hashCode()));
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForPhone(Context context, String text) {
- final List<LabeledIntent> actions = new ArrayList<>();
- final UserManager userManager = context.getSystemService(UserManager.class);
- final Bundle userRestrictions =
- userManager != null ? userManager.getUserRestrictions() : new Bundle();
- if (!userRestrictions.getBoolean(UserManager.DISALLOW_OUTGOING_CALLS, false)) {
- actions.add(
- new LabeledIntent(
- context.getString(R.string.dial),
- /* titleWithEntity */ null,
- context.getString(R.string.dial_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_DIAL)
- .setData(Uri.parse(String.format("tel:%s", text))),
- LabeledIntent.DEFAULT_REQUEST_CODE));
- }
- actions.add(
- new LabeledIntent(
- context.getString(R.string.add_contact),
- /* titleWithEntity */ null,
- context.getString(R.string.add_contact_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_INSERT_OR_EDIT)
- .setType(ContactsContract.Contacts.CONTENT_ITEM_TYPE)
- .putExtra(ContactsContract.Intents.Insert.PHONE, text),
- text.hashCode()));
- if (!userRestrictions.getBoolean(UserManager.DISALLOW_SMS, false)) {
- actions.add(
- new LabeledIntent(
- context.getString(R.string.sms),
- /* titleWithEntity */ null,
- context.getString(R.string.sms_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_SENDTO)
- .setData(Uri.parse(String.format("smsto:%s", text))),
- LabeledIntent.DEFAULT_REQUEST_CODE));
- }
- return actions;
+ private static List<LabeledIntent> createForPhone(Context context, String text) {
+ final List<LabeledIntent> actions = new ArrayList<>();
+ final UserManager userManager = context.getSystemService(UserManager.class);
+ final Bundle userRestrictions =
+ userManager != null ? userManager.getUserRestrictions() : new Bundle();
+ if (!userRestrictions.getBoolean(UserManager.DISALLOW_OUTGOING_CALLS, false)) {
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.dial),
+ /* titleWithEntity */ null,
+ context.getString(R.string.dial_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_DIAL).setData(Uri.parse(String.format("tel:%s", text))),
+ LabeledIntent.DEFAULT_REQUEST_CODE));
}
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.add_contact),
+ /* titleWithEntity */ null,
+ context.getString(R.string.add_contact_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_INSERT_OR_EDIT)
+ .setType(ContactsContract.Contacts.CONTENT_ITEM_TYPE)
+ .putExtra(ContactsContract.Intents.Insert.PHONE, text),
+ text.hashCode()));
+ if (!userRestrictions.getBoolean(UserManager.DISALLOW_SMS, false)) {
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.sms),
+ /* titleWithEntity */ null,
+ context.getString(R.string.sms_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_SENDTO).setData(Uri.parse(String.format("smsto:%s", text))),
+ LabeledIntent.DEFAULT_REQUEST_CODE));
+ }
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForAddress(Context context, String text) {
- final List<LabeledIntent> actions = new ArrayList<>();
- try {
- final String encText = URLEncoder.encode(text, "UTF-8");
- actions.add(
- new LabeledIntent(
- context.getString(R.string.map),
- /* titleWithEntity */ null,
- context.getString(R.string.map_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_VIEW)
- .setData(Uri.parse(String.format("geo:0,0?q=%s", encText))),
- LabeledIntent.DEFAULT_REQUEST_CODE));
- } catch (UnsupportedEncodingException e) {
- TcLog.e(TAG, "Could not encode address", e);
- }
- return actions;
+ private static List<LabeledIntent> createForAddress(Context context, String text) {
+ final List<LabeledIntent> actions = new ArrayList<>();
+ try {
+ final String encText = URLEncoder.encode(text, "UTF-8");
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.map),
+ /* titleWithEntity */ null,
+ context.getString(R.string.map_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_VIEW)
+ .setData(Uri.parse(String.format("geo:0,0?q=%s", encText))),
+ LabeledIntent.DEFAULT_REQUEST_CODE));
+ } catch (UnsupportedEncodingException e) {
+ TcLog.e(TAG, "Could not encode address", e);
}
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForUrl(Context context, String text) {
- if (Uri.parse(text).getScheme() == null) {
- text = "http://" + text;
- }
- final List<LabeledIntent> actions = new ArrayList<>();
- actions.add(
- new LabeledIntent(
- context.getString(R.string.browse),
- /* titleWithEntity */ null,
- context.getString(R.string.browse_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_VIEW)
- .setDataAndNormalize(Uri.parse(text))
- .putExtra(Browser.EXTRA_APPLICATION_ID, context.getPackageName()),
- LabeledIntent.DEFAULT_REQUEST_CODE));
- return actions;
+ private static List<LabeledIntent> createForUrl(Context context, String text) {
+ if (Uri.parse(text).getScheme() == null) {
+ text = "http://" + text;
}
+ final List<LabeledIntent> actions = new ArrayList<>();
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.browse),
+ /* titleWithEntity */ null,
+ context.getString(R.string.browse_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_VIEW)
+ .setDataAndNormalize(Uri.parse(text))
+ .putExtra(Browser.EXTRA_APPLICATION_ID, context.getPackageName()),
+ LabeledIntent.DEFAULT_REQUEST_CODE));
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForDatetime(
- Context context, String type, @Nullable Instant referenceTime, Instant parsedTime) {
- if (referenceTime == null) {
- // If no reference time was given, use now.
- referenceTime = Instant.now();
- }
- List<LabeledIntent> actions = new ArrayList<>();
- actions.add(createCalendarViewIntent(context, parsedTime));
- final long millisUntilEvent = referenceTime.until(parsedTime, MILLIS);
- if (millisUntilEvent > MIN_EVENT_FUTURE_MILLIS) {
- actions.add(createCalendarCreateEventIntent(context, parsedTime, type));
- }
- return actions;
+ private static List<LabeledIntent> createForDatetime(
+ Context context, String type, @Nullable Instant referenceTime, Instant parsedTime) {
+ if (referenceTime == null) {
+ // If no reference time was given, use now.
+ referenceTime = Instant.now();
}
+ List<LabeledIntent> actions = new ArrayList<>();
+ actions.add(createCalendarViewIntent(context, parsedTime));
+ final long millisUntilEvent = referenceTime.until(parsedTime, MILLIS);
+ if (millisUntilEvent > MIN_EVENT_FUTURE_MILLIS) {
+ actions.add(createCalendarCreateEventIntent(context, parsedTime, type));
+ }
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForFlight(Context context, String text) {
- final List<LabeledIntent> actions = new ArrayList<>();
- actions.add(
- new LabeledIntent(
- context.getString(R.string.view_flight),
- /* titleWithEntity */ null,
- context.getString(R.string.view_flight_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_WEB_SEARCH).putExtra(SearchManager.QUERY, text),
- text.hashCode()));
- return actions;
- }
+ private static List<LabeledIntent> createForFlight(Context context, String text) {
+ final List<LabeledIntent> actions = new ArrayList<>();
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.view_flight),
+ /* titleWithEntity */ null,
+ context.getString(R.string.view_flight_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_WEB_SEARCH).putExtra(SearchManager.QUERY, text),
+ text.hashCode()));
+ return actions;
+ }
- @NonNull
- private static LabeledIntent createCalendarViewIntent(Context context, Instant parsedTime) {
- Uri.Builder builder = CalendarContract.CONTENT_URI.buildUpon();
- builder.appendPath("time");
- ContentUris.appendId(builder, parsedTime.toEpochMilli());
- return new LabeledIntent(
- context.getString(R.string.view_calendar),
- /* titleWithEntity */ null,
- context.getString(R.string.view_calendar_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_VIEW).setData(builder.build()),
- LabeledIntent.DEFAULT_REQUEST_CODE);
- }
+ private static LabeledIntent createCalendarViewIntent(Context context, Instant parsedTime) {
+ Uri.Builder builder = CalendarContract.CONTENT_URI.buildUpon();
+ builder.appendPath("time");
+ ContentUris.appendId(builder, parsedTime.toEpochMilli());
+ return new LabeledIntent(
+ context.getString(R.string.view_calendar),
+ /* titleWithEntity */ null,
+ context.getString(R.string.view_calendar_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_VIEW).setData(builder.build()),
+ LabeledIntent.DEFAULT_REQUEST_CODE);
+ }
- @NonNull
- private static LabeledIntent createCalendarCreateEventIntent(
- Context context, Instant parsedTime, String type) {
- final boolean isAllDay = TextClassifier.TYPE_DATE.equals(type);
- return new LabeledIntent(
- context.getString(R.string.add_calendar_event),
- /* titleWithEntity */ null,
- context.getString(R.string.add_calendar_event_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_INSERT)
- .setData(CalendarContract.Events.CONTENT_URI)
- .putExtra(CalendarContract.EXTRA_EVENT_ALL_DAY, isAllDay)
- .putExtra(
- CalendarContract.EXTRA_EVENT_BEGIN_TIME, parsedTime.toEpochMilli())
- .putExtra(
- CalendarContract.EXTRA_EVENT_END_TIME,
- parsedTime.toEpochMilli() + DEFAULT_EVENT_DURATION),
- parsedTime.hashCode());
- }
+ private static LabeledIntent createCalendarCreateEventIntent(
+ Context context, Instant parsedTime, String type) {
+ final boolean isAllDay = TextClassifier.TYPE_DATE.equals(type);
+ return new LabeledIntent(
+ context.getString(R.string.add_calendar_event),
+ /* titleWithEntity */ null,
+ context.getString(R.string.add_calendar_event_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_INSERT)
+ .setData(CalendarContract.Events.CONTENT_URI)
+ .putExtra(CalendarContract.EXTRA_EVENT_ALL_DAY, isAllDay)
+ .putExtra(CalendarContract.EXTRA_EVENT_BEGIN_TIME, parsedTime.toEpochMilli())
+ .putExtra(
+ CalendarContract.EXTRA_EVENT_END_TIME,
+ parsedTime.toEpochMilli() + DEFAULT_EVENT_DURATION),
+ parsedTime.hashCode());
+ }
- @NonNull
- private static List<LabeledIntent> createForDictionary(Context context, String text) {
- final List<LabeledIntent> actions = new ArrayList<>();
- actions.add(
- new LabeledIntent(
- context.getString(R.string.define),
- /* titleWithEntity */ null,
- context.getString(R.string.define_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_DEFINE).putExtra(Intent.EXTRA_TEXT, text),
- text.hashCode()));
- return actions;
- }
+ private static List<LabeledIntent> createForDictionary(Context context, String text) {
+ final List<LabeledIntent> actions = new ArrayList<>();
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.define),
+ /* titleWithEntity */ null,
+ context.getString(R.string.define_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_DEFINE).putExtra(Intent.EXTRA_TEXT, text),
+ text.hashCode()));
+ return actions;
+ }
}
diff --git a/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java b/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java
index 53f5bf4..f5ef577 100644
--- a/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java
+++ b/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,69 +13,60 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.intent;
import android.content.Context;
-
-import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-import androidx.core.util.Preconditions;
-
import com.android.textclassifier.TcLog;
-
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.RemoteActionTemplate;
-
+import com.google.common.base.Preconditions;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
+import javax.annotation.Nullable;
/**
* Creates intents based on {@link RemoteActionTemplate} objects for a ClassificationResult.
*
* @hide
*/
-@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public final class TemplateClassificationIntentFactory implements ClassificationIntentFactory {
- private static final String TAG = "TemplateClassificationIntentFactory";
- private final TemplateIntentFactory mTemplateIntentFactory;
- private final ClassificationIntentFactory mFallback;
+ private static final String TAG = "TemplateIntentFactory";
+ private final TemplateIntentFactory templateIntentFactory;
+ private final ClassificationIntentFactory fallback;
- public TemplateClassificationIntentFactory(
- TemplateIntentFactory templateIntentFactory, ClassificationIntentFactory fallback) {
- mTemplateIntentFactory = Preconditions.checkNotNull(templateIntentFactory);
- mFallback = Preconditions.checkNotNull(fallback);
- }
+ public TemplateClassificationIntentFactory(
+ TemplateIntentFactory templateIntentFactory, ClassificationIntentFactory fallback) {
+ this.templateIntentFactory = Preconditions.checkNotNull(templateIntentFactory);
+ this.fallback = Preconditions.checkNotNull(fallback);
+ }
- /**
- * Returns a list of {@link LabeledIntent} that are constructed from the classification result.
- */
- @NonNull
- @Override
- public List<LabeledIntent> create(
- Context context,
- String text,
- boolean foreignText,
- @Nullable Instant referenceTime,
- @Nullable AnnotatorModel.ClassificationResult classification) {
- if (classification == null) {
- return Collections.emptyList();
- }
- RemoteActionTemplate[] remoteActionTemplates = classification.getRemoteActionTemplates();
- if (remoteActionTemplates == null) {
- // RemoteActionTemplate is missing, fallback.
- TcLog.w(
- TAG,
- "RemoteActionTemplate is missing, fallback to"
- + " LegacyClassificationIntentFactory.");
- return mFallback.create(context, text, foreignText, referenceTime, classification);
- }
- final List<LabeledIntent> labeledIntents =
- mTemplateIntentFactory.create(remoteActionTemplates);
- if (foreignText) {
- ClassificationIntentFactory.insertTranslateAction(labeledIntents, context, text.trim());
- }
- return labeledIntents;
+ /**
+ * Returns a list of {@link LabeledIntent} that are constructed from the classification result.
+ */
+ @Override
+ public List<LabeledIntent> create(
+ Context context,
+ String text,
+ boolean foreignText,
+ @Nullable Instant referenceTime,
+ @Nullable AnnotatorModel.ClassificationResult classification) {
+ if (classification == null) {
+ return Collections.emptyList();
}
+ RemoteActionTemplate[] remoteActionTemplates = classification.getRemoteActionTemplates();
+ if (remoteActionTemplates == null) {
+ // RemoteActionTemplate is missing, fallback.
+ TcLog.w(
+ TAG,
+ "RemoteActionTemplate is missing, fallback to" + " LegacyClassificationIntentFactory.");
+ return fallback.create(context, text, foreignText, referenceTime, classification);
+ }
+ final List<LabeledIntent> labeledIntents = templateIntentFactory.create(remoteActionTemplates);
+ if (foreignText) {
+ ClassificationIntentFactory.insertTranslateAction(labeledIntents, context, text.trim());
+ }
+ return labeledIntents;
+ }
}
diff --git a/java/src/com/android/textclassifier/intent/TemplateIntentFactory.java b/java/src/com/android/textclassifier/intent/TemplateIntentFactory.java
index 57f608a..718c988 100644
--- a/java/src/com/android/textclassifier/intent/TemplateIntentFactory.java
+++ b/java/src/com/android/textclassifier/intent/TemplateIntentFactory.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,144 +13,136 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.intent;
import android.content.Intent;
import android.net.Uri;
import android.os.Bundle;
import android.text.TextUtils;
-
-import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-
import com.android.textclassifier.TcLog;
-
import com.google.android.textclassifier.NamedVariant;
import com.google.android.textclassifier.RemoteActionTemplate;
-
import java.util.ArrayList;
import java.util.List;
+import javax.annotation.Nullable;
/**
* Creates intents based on {@link RemoteActionTemplate} objects.
*
* @hide
*/
-@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public final class TemplateIntentFactory {
- private static final String TAG = "TemplateIntentFactory";
+ private static final String TAG = "TemplateIntentFactory";
- /** Constructs and returns a list of {@link LabeledIntent} based on the given templates. */
- @Nullable
- public List<LabeledIntent> create(@NonNull RemoteActionTemplate[] remoteActionTemplates) {
- if (remoteActionTemplates.length == 0) {
- return new ArrayList<>();
- }
- final List<LabeledIntent> labeledIntents = new ArrayList<>();
- for (RemoteActionTemplate remoteActionTemplate : remoteActionTemplates) {
- if (!isValidTemplate(remoteActionTemplate)) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate skipped.");
- continue;
- }
- labeledIntents.add(
- new LabeledIntent(
- remoteActionTemplate.titleWithoutEntity,
- remoteActionTemplate.titleWithEntity,
- remoteActionTemplate.description,
- remoteActionTemplate.descriptionWithAppName,
- createIntent(remoteActionTemplate),
- remoteActionTemplate.requestCode == null
- ? LabeledIntent.DEFAULT_REQUEST_CODE
- : remoteActionTemplate.requestCode));
- }
- return labeledIntents;
+ /** Constructs and returns a list of {@link LabeledIntent} based on the given templates. */
+ @Nullable
+ public List<LabeledIntent> create(RemoteActionTemplate[] remoteActionTemplates) {
+ if (remoteActionTemplates.length == 0) {
+ return new ArrayList<>();
}
+ final List<LabeledIntent> labeledIntents = new ArrayList<>();
+ for (RemoteActionTemplate remoteActionTemplate : remoteActionTemplates) {
+ if (!isValidTemplate(remoteActionTemplate)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate skipped.");
+ continue;
+ }
+ labeledIntents.add(
+ new LabeledIntent(
+ remoteActionTemplate.titleWithoutEntity,
+ remoteActionTemplate.titleWithEntity,
+ remoteActionTemplate.description,
+ remoteActionTemplate.descriptionWithAppName,
+ createIntent(remoteActionTemplate),
+ remoteActionTemplate.requestCode == null
+ ? LabeledIntent.DEFAULT_REQUEST_CODE
+ : remoteActionTemplate.requestCode));
+ }
+ return labeledIntents;
+ }
- private static boolean isValidTemplate(@Nullable RemoteActionTemplate remoteActionTemplate) {
- if (remoteActionTemplate == null) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate: is null");
- return false;
- }
- if (TextUtils.isEmpty(remoteActionTemplate.titleWithEntity)
- && TextUtils.isEmpty(remoteActionTemplate.titleWithoutEntity)) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate: title is null");
- return false;
- }
- if (TextUtils.isEmpty(remoteActionTemplate.description)) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate: description is null");
- return false;
- }
- if (!TextUtils.isEmpty(remoteActionTemplate.packageName)) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate: package name is set");
- return false;
- }
- if (TextUtils.isEmpty(remoteActionTemplate.action)) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate: intent action not set");
- return false;
- }
- return true;
+ private static boolean isValidTemplate(@Nullable RemoteActionTemplate remoteActionTemplate) {
+ if (remoteActionTemplate == null) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: is null");
+ return false;
}
+ if (TextUtils.isEmpty(remoteActionTemplate.titleWithEntity)
+ && TextUtils.isEmpty(remoteActionTemplate.titleWithoutEntity)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: title is null");
+ return false;
+ }
+ if (TextUtils.isEmpty(remoteActionTemplate.description)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: description is null");
+ return false;
+ }
+ if (!TextUtils.isEmpty(remoteActionTemplate.packageName)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: package name is set");
+ return false;
+ }
+ if (TextUtils.isEmpty(remoteActionTemplate.action)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: intent action not set");
+ return false;
+ }
+ return true;
+ }
- private static Intent createIntent(RemoteActionTemplate remoteActionTemplate) {
- final Intent intent = new Intent(remoteActionTemplate.action);
- final Uri uri =
- TextUtils.isEmpty(remoteActionTemplate.data)
- ? null
- : Uri.parse(remoteActionTemplate.data).normalizeScheme();
- final String type =
- TextUtils.isEmpty(remoteActionTemplate.type)
- ? null
- : Intent.normalizeMimeType(remoteActionTemplate.type);
- intent.setDataAndType(uri, type);
- intent.setFlags(remoteActionTemplate.flags == null ? 0 : remoteActionTemplate.flags);
- if (remoteActionTemplate.category != null) {
- for (String category : remoteActionTemplate.category) {
- if (category != null) {
- intent.addCategory(category);
- }
- }
+ private static Intent createIntent(RemoteActionTemplate remoteActionTemplate) {
+ final Intent intent = new Intent(remoteActionTemplate.action);
+ final Uri uri =
+ TextUtils.isEmpty(remoteActionTemplate.data)
+ ? null
+ : Uri.parse(remoteActionTemplate.data).normalizeScheme();
+ final String type =
+ TextUtils.isEmpty(remoteActionTemplate.type)
+ ? null
+ : Intent.normalizeMimeType(remoteActionTemplate.type);
+ intent.setDataAndType(uri, type);
+ intent.setFlags(remoteActionTemplate.flags == null ? 0 : remoteActionTemplate.flags);
+ if (remoteActionTemplate.category != null) {
+ for (String category : remoteActionTemplate.category) {
+ if (category != null) {
+ intent.addCategory(category);
}
- intent.putExtras(nameVariantsToBundle(remoteActionTemplate.extras));
- return intent;
+ }
}
+ intent.putExtras(nameVariantsToBundle(remoteActionTemplate.extras));
+ return intent;
+ }
- /** Converts an array of {@link NamedVariant} to a Bundle and returns it. */
- public static Bundle nameVariantsToBundle(@Nullable NamedVariant[] namedVariants) {
- if (namedVariants == null) {
- return Bundle.EMPTY;
- }
- Bundle bundle = new Bundle();
- for (NamedVariant namedVariant : namedVariants) {
- if (namedVariant == null) {
- continue;
- }
- switch (namedVariant.getType()) {
- case NamedVariant.TYPE_INT:
- bundle.putInt(namedVariant.getName(), namedVariant.getInt());
- break;
- case NamedVariant.TYPE_LONG:
- bundle.putLong(namedVariant.getName(), namedVariant.getLong());
- break;
- case NamedVariant.TYPE_FLOAT:
- bundle.putFloat(namedVariant.getName(), namedVariant.getFloat());
- break;
- case NamedVariant.TYPE_DOUBLE:
- bundle.putDouble(namedVariant.getName(), namedVariant.getDouble());
- break;
- case NamedVariant.TYPE_BOOL:
- bundle.putBoolean(namedVariant.getName(), namedVariant.getBool());
- break;
- case NamedVariant.TYPE_STRING:
- bundle.putString(namedVariant.getName(), namedVariant.getString());
- break;
- default:
- TcLog.w(
- TAG,
- "Unsupported type found in nameVariantsToBundle : "
- + namedVariant.getType());
- }
- }
- return bundle;
+ /** Converts an array of {@link NamedVariant} to a Bundle and returns it. */
+ public static Bundle nameVariantsToBundle(@Nullable NamedVariant[] namedVariants) {
+ if (namedVariants == null) {
+ return Bundle.EMPTY;
}
+ Bundle bundle = new Bundle();
+ for (NamedVariant namedVariant : namedVariants) {
+ if (namedVariant == null) {
+ continue;
+ }
+ switch (namedVariant.getType()) {
+ case NamedVariant.TYPE_INT:
+ bundle.putInt(namedVariant.getName(), namedVariant.getInt());
+ break;
+ case NamedVariant.TYPE_LONG:
+ bundle.putLong(namedVariant.getName(), namedVariant.getLong());
+ break;
+ case NamedVariant.TYPE_FLOAT:
+ bundle.putFloat(namedVariant.getName(), namedVariant.getFloat());
+ break;
+ case NamedVariant.TYPE_DOUBLE:
+ bundle.putDouble(namedVariant.getName(), namedVariant.getDouble());
+ break;
+ case NamedVariant.TYPE_BOOL:
+ bundle.putBoolean(namedVariant.getName(), namedVariant.getBool());
+ break;
+ case NamedVariant.TYPE_STRING:
+ bundle.putString(namedVariant.getName(), namedVariant.getString());
+ break;
+ default:
+ TcLog.w(
+ TAG, "Unsupported type found in nameVariantsToBundle : " + namedVariant.getType());
+ }
+ }
+ return bundle;
+ }
}
diff --git a/java/src/com/android/textclassifier/logging/GenerateLinksLogger.java b/java/src/com/android/textclassifier/logging/GenerateLinksLogger.java
index d06104c..5da8e93 100644
--- a/java/src/com/android/textclassifier/logging/GenerateLinksLogger.java
+++ b/java/src/com/android/textclassifier/logging/GenerateLinksLogger.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,128 +19,126 @@
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextClassifierEvent;
import android.view.textclassifier.TextLinks;
-
-import androidx.annotation.Nullable;
import androidx.collection.ArrayMap;
-import androidx.core.util.Preconditions;
-
import com.android.textclassifier.TcLog;
import com.android.textclassifier.TextClassifierStatsLog;
-
+import com.google.common.base.Preconditions;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.UUID;
+import javax.annotation.Nullable;
/** A helper for logging calls to generateLinks. */
public final class GenerateLinksLogger {
- private static final String LOG_TAG = "GenerateLinksLogger";
+ private static final String LOG_TAG = "GenerateLinksLogger";
- private final Random mRng;
- private final int mSampleRate;
+ private final Random rng;
+ private final int sampleRate;
- /**
- * @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01
- * chance that a call to logGenerateLinks results in an event being written). To write all
- * events, pass 1.
- */
- public GenerateLinksLogger(int sampleRate) {
- mSampleRate = sampleRate;
- mRng = new Random();
+ /**
+ * @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01
+ * chance that a call to logGenerateLinks results in an event being written). To write all
+ * events, pass 1.
+ */
+ public GenerateLinksLogger(int sampleRate) {
+ this.sampleRate = sampleRate;
+ rng = new Random();
+ }
+
+ /** Logs statistics about a call to generateLinks. */
+ public void logGenerateLinks(
+ CharSequence text, TextLinks links, String callingPackageName, long latencyMs) {
+ Preconditions.checkNotNull(text);
+ Preconditions.checkNotNull(links);
+ Preconditions.checkNotNull(callingPackageName);
+ if (!shouldLog()) {
+ return;
}
- /** Logs statistics about a call to generateLinks. */
- public void logGenerateLinks(
- CharSequence text, TextLinks links, String callingPackageName, long latencyMs) {
- Preconditions.checkNotNull(text);
- Preconditions.checkNotNull(links);
- Preconditions.checkNotNull(callingPackageName);
- if (!shouldLog()) {
- return;
- }
-
- // Always populate the total stats, and per-entity stats for each entity type detected.
- final LinkifyStats totalStats = new LinkifyStats();
- final Map<String, LinkifyStats> perEntityTypeStats = new ArrayMap<>();
- for (TextLinks.TextLink link : links.getLinks()) {
- if (link.getEntityCount() == 0) continue;
- final String entityType = link.getEntity(0);
- if (entityType == null
- || TextClassifier.TYPE_OTHER.equals(entityType)
- || TextClassifier.TYPE_UNKNOWN.equals(entityType)) {
- continue;
- }
- totalStats.countLink(link);
- perEntityTypeStats.computeIfAbsent(entityType, k -> new LinkifyStats()).countLink(link);
- }
-
- final String callId = UUID.randomUUID().toString();
- writeStats(callId, callingPackageName, null, totalStats, text, latencyMs);
- for (Map.Entry<String, LinkifyStats> entry : perEntityTypeStats.entrySet()) {
- writeStats(
- callId, callingPackageName, entry.getKey(), entry.getValue(), text, latencyMs);
- }
+ // Always populate the total stats, and per-entity stats for each entity type detected.
+ final LinkifyStats totalStats = new LinkifyStats();
+ final Map<String, LinkifyStats> perEntityTypeStats = new ArrayMap<>();
+ for (TextLinks.TextLink link : links.getLinks()) {
+ if (link.getEntityCount() == 0) {
+ continue;
+ }
+ final String entityType = link.getEntity(0);
+ if (entityType == null
+ || TextClassifier.TYPE_OTHER.equals(entityType)
+ || TextClassifier.TYPE_UNKNOWN.equals(entityType)) {
+ continue;
+ }
+ totalStats.countLink(link);
+ perEntityTypeStats.computeIfAbsent(entityType, k -> new LinkifyStats()).countLink(link);
}
- /**
- * Returns whether this particular event should be logged.
- *
- * <p>Sampling is used to reduce the amount of logging data generated.
- */
- private boolean shouldLog() {
- if (mSampleRate <= 1) {
- return true;
- } else {
- return mRng.nextInt(mSampleRate) == 0;
- }
+ final String callId = UUID.randomUUID().toString();
+ writeStats(callId, callingPackageName, null, totalStats, text, latencyMs);
+ for (Map.Entry<String, LinkifyStats> entry : perEntityTypeStats.entrySet()) {
+ writeStats(callId, callingPackageName, entry.getKey(), entry.getValue(), text, latencyMs);
}
+ }
- /** Writes a log event for the given stats. */
- private void writeStats(
- String callId,
- String callingPackageName,
- @Nullable String entityType,
- LinkifyStats stats,
- CharSequence text,
- long latencyMs) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
- callId,
- TextClassifierEvent.TYPE_LINKS_GENERATED,
- /*modelName=*/ null,
- TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN,
- /*eventIndex=*/ 0,
- entityType,
- stats.mNumLinks,
- stats.mNumLinksTextLength,
- text.length(),
- latencyMs,
- callingPackageName);
- if (TcLog.ENABLE_FULL_LOGGING) {
- TcLog.v(
- LOG_TAG,
- String.format(
- Locale.US,
- "%s:%s %d links (%d/%d chars) %dms %s",
- callId,
- entityType,
- stats.mNumLinks,
- stats.mNumLinksTextLength,
- text.length(),
- latencyMs,
- callingPackageName));
- }
+ /**
+ * Returns whether this particular event should be logged.
+ *
+ * <p>Sampling is used to reduce the amount of logging data generated.
+ */
+ private boolean shouldLog() {
+ if (sampleRate <= 1) {
+ return true;
+ } else {
+ return rng.nextInt(sampleRate) == 0;
}
+ }
- /** Helper class for storing per-entity type statistics. */
- private static final class LinkifyStats {
- int mNumLinks;
- int mNumLinksTextLength;
-
- void countLink(TextLinks.TextLink link) {
- mNumLinks += 1;
- mNumLinksTextLength += link.getEnd() - link.getStart();
- }
+ /** Writes a log event for the given stats. */
+ private static void writeStats(
+ String callId,
+ String callingPackageName,
+ @Nullable String entityType,
+ LinkifyStats stats,
+ CharSequence text,
+ long latencyMs) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
+ callId,
+ TextClassifierEvent.TYPE_LINKS_GENERATED,
+ /*modelName=*/ null,
+ TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN,
+ /*eventIndex=*/ 0,
+ entityType,
+ stats.numLinks,
+ stats.numLinksTextLength,
+ text.length(),
+ latencyMs,
+ callingPackageName);
+ if (TcLog.ENABLE_FULL_LOGGING) {
+ TcLog.v(
+ LOG_TAG,
+ String.format(
+ Locale.US,
+ "%s:%s %d links (%d/%d chars) %dms %s",
+ callId,
+ entityType,
+ stats.numLinks,
+ stats.numLinksTextLength,
+ text.length(),
+ latencyMs,
+ callingPackageName));
}
+ }
+
+ /** Helper class for storing per-entity type statistics. */
+ private static final class LinkifyStats {
+ int numLinks;
+ int numLinksTextLength;
+
+ void countLink(TextLinks.TextLink link) {
+ numLinks += 1;
+ numLinksTextLength += link.getEnd() - link.getStart();
+ }
+ }
}
diff --git a/java/src/com/android/textclassifier/logging/ResultIdUtils.java b/java/src/com/android/textclassifier/logging/ResultIdUtils.java
index de21959..c0f8b2a 100644
--- a/java/src/com/android/textclassifier/logging/ResultIdUtils.java
+++ b/java/src/com/android/textclassifier/logging/ResultIdUtils.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,59 +13,59 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.logging;
//
import android.content.Context;
-
-import androidx.annotation.Nullable;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.StringJoiner;
+import javax.annotation.Nullable;
/** Provide utils to generate and parse the result id. */
public final class ResultIdUtils {
- private static final String LOG_TAG = "ResultIdUtils";
- private static final String CLASSIFIER_ID = "androidtc";
+ private static final String CLASSIFIER_ID = "androidtc";
- /** Creates a string id that may be used to identify a TextClassifier result. */
- public static String createId(
- Context context,
- String text,
- int start,
- int end,
- int modelVersion,
- List<Locale> modelLocales) {
- Preconditions.checkNotNull(text);
- Preconditions.checkNotNull(context);
- Preconditions.checkNotNull(modelLocales);
- final int hash = Objects.hash(text, start, end, context.getPackageName());
- return createId(modelVersion, modelLocales, hash);
- }
+ /** Creates a string id that may be used to identify a TextClassifier result. */
+ public static String createId(
+ Context context,
+ String text,
+ int start,
+ int end,
+ int modelVersion,
+ List<Locale> modelLocales) {
+ Preconditions.checkNotNull(text);
+ Preconditions.checkNotNull(context);
+ Preconditions.checkNotNull(modelLocales);
+ final int hash = Objects.hash(text, start, end, context.getPackageName());
+ return createId(modelVersion, modelLocales, hash);
+ }
- /** Creates a string id that may be used to identify a TextClassifier result. */
- public static String createId(int modelVersion, List<Locale> modelLocales, int hash) {
- final StringJoiner localesJoiner = new StringJoiner(",");
- for (Locale locale : modelLocales) {
- localesJoiner.add(locale.toLanguageTag());
- }
- final String modelName =
- String.format(Locale.US, "%s_v%d", localesJoiner.toString(), modelVersion);
- return String.format(Locale.US, "%s|%s|%d", CLASSIFIER_ID, modelName, hash);
+ /** Creates a string id that may be used to identify a TextClassifier result. */
+ public static String createId(int modelVersion, List<Locale> modelLocales, int hash) {
+ final StringJoiner localesJoiner = new StringJoiner(",");
+ for (Locale locale : modelLocales) {
+ localesJoiner.add(locale.toLanguageTag());
}
+ final String modelName =
+ String.format(Locale.US, "%s_v%d", localesJoiner.toString(), modelVersion);
+ return String.format(Locale.US, "%s|%s|%d", CLASSIFIER_ID, modelName, hash);
+ }
- static String getModelName(@Nullable String signature) {
- if (signature == null) {
- return "";
- }
- final int start = signature.indexOf("|") + 1;
- final int end = signature.indexOf("|", start);
- if (start >= 1 && end >= start) {
- return signature.substring(start, end);
- }
- return "";
+ static String getModelName(@Nullable String signature) {
+ if (signature == null) {
+ return "";
}
+ final int start = signature.indexOf("|") + 1;
+ final int end = signature.indexOf("|", start);
+ if (start >= 1 && end >= start) {
+ return signature.substring(start, end);
+ }
+ return "";
+ }
+
+ private ResultIdUtils() {}
}
diff --git a/java/src/com/android/textclassifier/logging/SelectionEventConverter.java b/java/src/com/android/textclassifier/logging/SelectionEventConverter.java
index 2d492a9..05fdf9f 100644
--- a/java/src/com/android/textclassifier/logging/SelectionEventConverter.java
+++ b/java/src/com/android/textclassifier/logging/SelectionEventConverter.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,84 +19,85 @@
import android.view.textclassifier.SelectionEvent;
import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassifierEvent;
-
-import androidx.annotation.Nullable;
+import javax.annotation.Nullable;
/** Helper class to convert a {@link SelectionEvent} to a {@link TextClassifierEvent}. */
public final class SelectionEventConverter {
- /** Converts a {@link SelectionEvent} to a {@link TextClassifierEvent}. */
- @Nullable
- public static TextClassifierEvent toTextClassifierEvent(SelectionEvent selectionEvent) {
- TextClassificationContext textClassificationContext = null;
- if (selectionEvent.getPackageName() != null && selectionEvent.getWidgetType() != null) {
- textClassificationContext =
- new TextClassificationContext.Builder(
- selectionEvent.getPackageName(), selectionEvent.getWidgetType())
- .setWidgetVersion(selectionEvent.getWidgetVersion())
- .build();
- }
- if (selectionEvent.getInvocationMethod() == SelectionEvent.INVOCATION_LINK) {
- return new TextClassifierEvent.TextLinkifyEvent.Builder(
- convertEventType(selectionEvent.getEventType()))
- .setEventContext(textClassificationContext)
- .setResultId(selectionEvent.getResultId())
- .setEventIndex(selectionEvent.getEventIndex())
- .setEntityTypes(selectionEvent.getEntityType())
- .build();
- }
- if (selectionEvent.getInvocationMethod() == SelectionEvent.INVOCATION_MANUAL) {
- return new TextClassifierEvent.TextSelectionEvent.Builder(
- convertEventType(selectionEvent.getEventType()))
- .setEventContext(textClassificationContext)
- .setResultId(selectionEvent.getResultId())
- .setEventIndex(selectionEvent.getEventIndex())
- .setEntityTypes(selectionEvent.getEntityType())
- .setRelativeWordStartIndex(selectionEvent.getStart())
- .setRelativeWordEndIndex(selectionEvent.getEnd())
- .setRelativeSuggestedWordStartIndex(selectionEvent.getSmartStart())
- .setRelativeSuggestedWordEndIndex(selectionEvent.getSmartEnd())
- .build();
- }
- return null;
+ /** Converts a {@link SelectionEvent} to a {@link TextClassifierEvent}. */
+ @Nullable
+ public static TextClassifierEvent toTextClassifierEvent(SelectionEvent selectionEvent) {
+ TextClassificationContext textClassificationContext = null;
+ if (selectionEvent.getPackageName() != null && selectionEvent.getWidgetType() != null) {
+ textClassificationContext =
+ new TextClassificationContext.Builder(
+ selectionEvent.getPackageName(), selectionEvent.getWidgetType())
+ .setWidgetVersion(selectionEvent.getWidgetVersion())
+ .build();
}
+ if (selectionEvent.getInvocationMethod() == SelectionEvent.INVOCATION_LINK) {
+ return new TextClassifierEvent.TextLinkifyEvent.Builder(
+ convertEventType(selectionEvent.getEventType()))
+ .setEventContext(textClassificationContext)
+ .setResultId(selectionEvent.getResultId())
+ .setEventIndex(selectionEvent.getEventIndex())
+ .setEntityTypes(selectionEvent.getEntityType())
+ .build();
+ }
+ if (selectionEvent.getInvocationMethod() == SelectionEvent.INVOCATION_MANUAL) {
+ return new TextClassifierEvent.TextSelectionEvent.Builder(
+ convertEventType(selectionEvent.getEventType()))
+ .setEventContext(textClassificationContext)
+ .setResultId(selectionEvent.getResultId())
+ .setEventIndex(selectionEvent.getEventIndex())
+ .setEntityTypes(selectionEvent.getEntityType())
+ .setRelativeWordStartIndex(selectionEvent.getStart())
+ .setRelativeWordEndIndex(selectionEvent.getEnd())
+ .setRelativeSuggestedWordStartIndex(selectionEvent.getSmartStart())
+ .setRelativeSuggestedWordEndIndex(selectionEvent.getSmartEnd())
+ .build();
+ }
+ return null;
+ }
- private static int convertEventType(int eventType) {
- switch (eventType) {
- case SelectionEvent.EVENT_SELECTION_STARTED:
- return TextClassifierEvent.TYPE_SELECTION_STARTED;
- case SelectionEvent.EVENT_SELECTION_MODIFIED:
- return TextClassifierEvent.TYPE_SELECTION_MODIFIED;
- case SelectionEvent.EVENT_SMART_SELECTION_SINGLE:
- return SelectionEvent.EVENT_SMART_SELECTION_SINGLE;
- case SelectionEvent.EVENT_SMART_SELECTION_MULTI:
- return SelectionEvent.EVENT_SMART_SELECTION_MULTI;
- case SelectionEvent.EVENT_AUTO_SELECTION:
- return SelectionEvent.EVENT_AUTO_SELECTION;
- case SelectionEvent.ACTION_OVERTYPE:
- return TextClassifierEvent.TYPE_OVERTYPE;
- case SelectionEvent.ACTION_COPY:
- return TextClassifierEvent.TYPE_COPY_ACTION;
- case SelectionEvent.ACTION_PASTE:
- return TextClassifierEvent.TYPE_PASTE_ACTION;
- case SelectionEvent.ACTION_CUT:
- return TextClassifierEvent.TYPE_CUT_ACTION;
- case SelectionEvent.ACTION_SHARE:
- return TextClassifierEvent.TYPE_SHARE_ACTION;
- case SelectionEvent.ACTION_SMART_SHARE:
- return TextClassifierEvent.TYPE_SMART_ACTION;
- case SelectionEvent.ACTION_DRAG:
- return TextClassifierEvent.TYPE_SELECTION_DRAG;
- case SelectionEvent.ACTION_ABANDON:
- return TextClassifierEvent.TYPE_SELECTION_DESTROYED;
- case SelectionEvent.ACTION_OTHER:
- return TextClassifierEvent.TYPE_OTHER_ACTION;
- case SelectionEvent.ACTION_SELECT_ALL:
- return TextClassifierEvent.TYPE_SELECT_ALL;
- case SelectionEvent.ACTION_RESET:
- return TextClassifierEvent.TYPE_SELECTION_RESET;
- default:
- return 0;
- }
+ private static int convertEventType(int eventType) {
+ switch (eventType) {
+ case SelectionEvent.EVENT_SELECTION_STARTED:
+ return TextClassifierEvent.TYPE_SELECTION_STARTED;
+ case SelectionEvent.EVENT_SELECTION_MODIFIED:
+ return TextClassifierEvent.TYPE_SELECTION_MODIFIED;
+ case SelectionEvent.EVENT_SMART_SELECTION_SINGLE:
+ return SelectionEvent.EVENT_SMART_SELECTION_SINGLE;
+ case SelectionEvent.EVENT_SMART_SELECTION_MULTI:
+ return SelectionEvent.EVENT_SMART_SELECTION_MULTI;
+ case SelectionEvent.EVENT_AUTO_SELECTION:
+ return SelectionEvent.EVENT_AUTO_SELECTION;
+ case SelectionEvent.ACTION_OVERTYPE:
+ return TextClassifierEvent.TYPE_OVERTYPE;
+ case SelectionEvent.ACTION_COPY:
+ return TextClassifierEvent.TYPE_COPY_ACTION;
+ case SelectionEvent.ACTION_PASTE:
+ return TextClassifierEvent.TYPE_PASTE_ACTION;
+ case SelectionEvent.ACTION_CUT:
+ return TextClassifierEvent.TYPE_CUT_ACTION;
+ case SelectionEvent.ACTION_SHARE:
+ return TextClassifierEvent.TYPE_SHARE_ACTION;
+ case SelectionEvent.ACTION_SMART_SHARE:
+ return TextClassifierEvent.TYPE_SMART_ACTION;
+ case SelectionEvent.ACTION_DRAG:
+ return TextClassifierEvent.TYPE_SELECTION_DRAG;
+ case SelectionEvent.ACTION_ABANDON:
+ return TextClassifierEvent.TYPE_SELECTION_DESTROYED;
+ case SelectionEvent.ACTION_OTHER:
+ return TextClassifierEvent.TYPE_OTHER_ACTION;
+ case SelectionEvent.ACTION_SELECT_ALL:
+ return TextClassifierEvent.TYPE_SELECT_ALL;
+ case SelectionEvent.ACTION_RESET:
+ return TextClassifierEvent.TYPE_SELECTION_RESET;
+ default:
+ return 0;
}
+ }
+
+ private SelectionEventConverter() {}
}
diff --git a/java/src/com/android/textclassifier/logging/TextClassifierEventLogger.java b/java/src/com/android/textclassifier/logging/TextClassifierEventLogger.java
index 29e38a1..a112fdf 100644
--- a/java/src/com/android/textclassifier/logging/TextClassifierEventLogger.java
+++ b/java/src/com/android/textclassifier/logging/TextClassifierEventLogger.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,203 +13,203 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.logging;
import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassificationSessionId;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextClassifierEvent;
-
-import androidx.annotation.Nullable;
-import androidx.core.util.Preconditions;
-
import com.android.textclassifier.TcLog;
import com.android.textclassifier.TextClassifierStatsLog;
+import com.google.common.base.Preconditions;
+import javax.annotation.Nullable;
/** Logs {@link android.view.textclassifier.TextClassifierEvent}. */
public final class TextClassifierEventLogger {
- private static final String TAG = "TextClassifierEventLogger";
+ private static final String TAG = "TCEventLogger";
- /** Emits a text classifier event to the logs. */
- public void writeEvent(
- @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) {
- Preconditions.checkNotNull(event);
- if (TcLog.ENABLE_FULL_LOGGING) {
- TcLog.v(TAG, "TextClassifierEventLogger.writeEvent: event = [" + event + "]");
- }
- if (event instanceof TextClassifierEvent.TextSelectionEvent) {
- logTextSelectionEvent(sessionId, (TextClassifierEvent.TextSelectionEvent) event);
- } else if (event instanceof TextClassifierEvent.TextLinkifyEvent) {
- logTextLinkifyEvent(sessionId, (TextClassifierEvent.TextLinkifyEvent) event);
- } else if (event instanceof TextClassifierEvent.ConversationActionsEvent) {
- logConversationActionsEvent(
- sessionId, (TextClassifierEvent.ConversationActionsEvent) event);
- } else if (event instanceof TextClassifierEvent.LanguageDetectionEvent) {
- logLanguageDetectionEvent(
- sessionId, (TextClassifierEvent.LanguageDetectionEvent) event);
- } else {
- TcLog.w(TAG, "Unexpected events, category=" + event.getEventCategory());
- }
+ /** Emits a text classifier event to the logs. */
+ public void writeEvent(
+ @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) {
+ Preconditions.checkNotNull(event);
+ if (TcLog.ENABLE_FULL_LOGGING) {
+ TcLog.v(TAG, "TextClassifierEventLogger.writeEvent: event = [" + event + "]");
}
-
- private void logTextSelectionEvent(
- @Nullable TextClassificationSessionId sessionId,
- TextClassifierEvent.TextSelectionEvent event) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.TEXT_SELECTION_EVENT,
- sessionId == null ? null : sessionId.flattenToString(),
- event.getEventType(),
- getModelName(event),
- getWidgetType(event),
- event.getEventIndex(),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- event.getRelativeWordStartIndex(),
- event.getRelativeWordEndIndex(),
- event.getRelativeSuggestedWordStartIndex(),
- event.getRelativeSuggestedWordEndIndex(),
- getPackageName(event));
+ if (event instanceof TextClassifierEvent.TextSelectionEvent) {
+ logTextSelectionEvent(sessionId, (TextClassifierEvent.TextSelectionEvent) event);
+ } else if (event instanceof TextClassifierEvent.TextLinkifyEvent) {
+ logTextLinkifyEvent(sessionId, (TextClassifierEvent.TextLinkifyEvent) event);
+ } else if (event instanceof TextClassifierEvent.ConversationActionsEvent) {
+ logConversationActionsEvent(sessionId, (TextClassifierEvent.ConversationActionsEvent) event);
+ } else if (event instanceof TextClassifierEvent.LanguageDetectionEvent) {
+ logLanguageDetectionEvent(sessionId, (TextClassifierEvent.LanguageDetectionEvent) event);
+ } else {
+ TcLog.w(TAG, "Unexpected events, category=" + event.getEventCategory());
}
+ }
- private void logTextLinkifyEvent(
- TextClassificationSessionId sessionId, TextClassifierEvent.TextLinkifyEvent event) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
- sessionId == null ? null : sessionId.flattenToString(),
- event.getEventType(),
- getModelName(event),
- getWidgetType(event),
- event.getEventIndex(),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- /*numOfLinks=*/ 0,
- /*linkedTextLength=*/ 0,
- /*textLength=*/ 0,
- /*latencyInMillis=*/ 0L,
- getPackageName(event));
+ private void logTextSelectionEvent(
+ @Nullable TextClassificationSessionId sessionId,
+ TextClassifierEvent.TextSelectionEvent event) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_SELECTION_EVENT,
+ sessionId == null ? null : sessionId.flattenToString(),
+ event.getEventType(),
+ getModelName(event),
+ getWidgetType(event),
+ event.getEventIndex(),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ event.getRelativeWordStartIndex(),
+ event.getRelativeWordEndIndex(),
+ event.getRelativeSuggestedWordStartIndex(),
+ event.getRelativeSuggestedWordEndIndex(),
+ getPackageName(event));
+ }
+
+ private void logTextLinkifyEvent(
+ TextClassificationSessionId sessionId, TextClassifierEvent.TextLinkifyEvent event) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
+ sessionId == null ? null : sessionId.flattenToString(),
+ event.getEventType(),
+ getModelName(event),
+ getWidgetType(event),
+ event.getEventIndex(),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ /*numOfLinks=*/ 0,
+ /*linkedTextLength=*/ 0,
+ /*textLength=*/ 0,
+ /*latencyInMillis=*/ 0L,
+ getPackageName(event));
+ }
+
+ private void logConversationActionsEvent(
+ @Nullable TextClassificationSessionId sessionId,
+ TextClassifierEvent.ConversationActionsEvent event) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.CONVERSATION_ACTIONS_EVENT,
+ sessionId == null
+ ? event.getResultId() // TODO: Update ExtServices to set the session id.
+ : sessionId.flattenToString(),
+ event.getEventType(),
+ getModelName(event),
+ getWidgetType(event),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ getItemAt(event.getEntityTypes(), /* index= */ 1),
+ getItemAt(event.getEntityTypes(), /* index= */ 2),
+ getFloatAt(event.getScores(), /* index= */ 0),
+ getPackageName(event));
+ }
+
+ private void logLanguageDetectionEvent(
+ @Nullable TextClassificationSessionId sessionId,
+ TextClassifierEvent.LanguageDetectionEvent event) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.LANGUAGE_DETECTION_EVENT,
+ sessionId == null ? null : sessionId.flattenToString(),
+ event.getEventType(),
+ getModelName(event),
+ getWidgetType(event),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ getFloatAt(event.getScores(), /* index= */ 0),
+ getIntAt(event.getActionIndices(), /* index= */ 0),
+ getPackageName(event));
+ }
+
+ @Nullable
+ private static <T> T getItemAt(@Nullable T[] array, int index) {
+ if (array == null) {
+ return null;
}
-
- private void logConversationActionsEvent(
- @Nullable TextClassificationSessionId sessionId,
- TextClassifierEvent.ConversationActionsEvent event) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.CONVERSATION_ACTIONS_EVENT,
- sessionId == null
- ? event.getResultId() // TODO: Update ExtServices to set the session id.
- : sessionId.flattenToString(),
- event.getEventType(),
- getModelName(event),
- getWidgetType(event),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- getItemAt(event.getEntityTypes(), /* index= */ 1),
- getItemAt(event.getEntityTypes(), /* index= */ 2),
- getFloatAt(event.getScores(), /* index= */ 0),
- getPackageName(event));
+ if (index >= array.length) {
+ return null;
}
+ return array[index];
+ }
- private void logLanguageDetectionEvent(
- @Nullable TextClassificationSessionId sessionId,
- TextClassifierEvent.LanguageDetectionEvent event) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.LANGUAGE_DETECTION_EVENT,
- sessionId == null ? null : sessionId.flattenToString(),
- event.getEventType(),
- getModelName(event),
- getWidgetType(event),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- getFloatAt(event.getScores(), /* index= */ 0),
- getIntAt(event.getActionIndices(), /* index= */ 0),
- getPackageName(event));
+ private static float getFloatAt(@Nullable float[] array, int index) {
+ if (array == null) {
+ return 0f;
}
-
- @Nullable
- private static <T> T getItemAt(@Nullable T[] array, int index) {
- if (array == null) {
- return null;
- }
- if (index >= array.length) {
- return null;
- }
- return array[index];
+ if (index >= array.length) {
+ return 0f;
}
+ return array[index];
+ }
- private static float getFloatAt(@Nullable float[] array, int index) {
- if (array == null) {
- return 0f;
- }
- if (index >= array.length) {
- return 0f;
- }
- return array[index];
+ private static int getIntAt(@Nullable int[] array, int index) {
+ if (array == null) {
+ return 0;
}
-
- private static int getIntAt(@Nullable int[] array, int index) {
- if (array == null) {
- return 0;
- }
- if (index >= array.length) {
- return 0;
- }
- return array[index];
+ if (index >= array.length) {
+ return 0;
}
+ return array[index];
+ }
- private static String getModelName(TextClassifierEvent event) {
- if (event.getModelName() != null) {
- return event.getModelName();
- }
- return ResultIdUtils.getModelName(event.getResultId());
+ private static String getModelName(TextClassifierEvent event) {
+ if (event.getModelName() != null) {
+ return event.getModelName();
}
+ return ResultIdUtils.getModelName(event.getResultId());
+ }
- @Nullable
- private static String getPackageName(TextClassifierEvent event) {
- TextClassificationContext eventContext = event.getEventContext();
- if (eventContext == null) {
- return null;
- }
- return eventContext.getPackageName();
+ @Nullable
+ private static String getPackageName(TextClassifierEvent event) {
+ TextClassificationContext eventContext = event.getEventContext();
+ if (eventContext == null) {
+ return null;
}
+ return eventContext.getPackageName();
+ }
- private static int getWidgetType(TextClassifierEvent event) {
- TextClassificationContext eventContext = event.getEventContext();
- if (eventContext == null) {
- return WidgetType.WIDGET_TYPE_UNKNOWN;
- }
- switch (eventContext.getWidgetType()) {
- case TextClassifier.WIDGET_TYPE_UNKNOWN:
- return WidgetType.WIDGET_TYPE_UNKNOWN;
- case TextClassifier.WIDGET_TYPE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_EDITTEXT:
- return WidgetType.WIDGET_TYPE_EDITTEXT;
- case TextClassifier.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_WEBVIEW:
- return WidgetType.WIDGET_TYPE_WEBVIEW;
- case TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW:
- return WidgetType.WIDGET_TYPE_EDIT_WEBVIEW;
- case TextClassifier.WIDGET_TYPE_CUSTOM_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_CUSTOM_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_CUSTOM_EDITTEXT:
- return WidgetType.WIDGET_TYPE_CUSTOM_EDITTEXT;
- case TextClassifier.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_NOTIFICATION:
- return WidgetType.WIDGET_TYPE_NOTIFICATION;
- }
+ private static int getWidgetType(TextClassifierEvent event) {
+ TextClassificationContext eventContext = event.getEventContext();
+ if (eventContext == null) {
+ return WidgetType.WIDGET_TYPE_UNKNOWN;
+ }
+ switch (eventContext.getWidgetType()) {
+ case TextClassifier.WIDGET_TYPE_UNKNOWN:
return WidgetType.WIDGET_TYPE_UNKNOWN;
+ case TextClassifier.WIDGET_TYPE_TEXTVIEW:
+ return WidgetType.WIDGET_TYPE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_EDITTEXT:
+ return WidgetType.WIDGET_TYPE_EDITTEXT;
+ case TextClassifier.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW:
+ return WidgetType.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_WEBVIEW:
+ return WidgetType.WIDGET_TYPE_WEBVIEW;
+ case TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW:
+ return WidgetType.WIDGET_TYPE_EDIT_WEBVIEW;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_TEXTVIEW:
+ return WidgetType.WIDGET_TYPE_CUSTOM_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_EDITTEXT:
+ return WidgetType.WIDGET_TYPE_CUSTOM_EDITTEXT;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW:
+ return WidgetType.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_NOTIFICATION:
+ return WidgetType.WIDGET_TYPE_NOTIFICATION;
+ default: // fall out
}
+ return WidgetType.WIDGET_TYPE_UNKNOWN;
+ }
- /** Widget type constants for logging. */
- public interface WidgetType {
- // Sync these constants with textclassifier_enums.proto.
- int WIDGET_TYPE_UNKNOWN = 0;
- int WIDGET_TYPE_TEXTVIEW = 1;
- int WIDGET_TYPE_EDITTEXT = 2;
- int WIDGET_TYPE_UNSELECTABLE_TEXTVIEW = 3;
- int WIDGET_TYPE_WEBVIEW = 4;
- int WIDGET_TYPE_EDIT_WEBVIEW = 5;
- int WIDGET_TYPE_CUSTOM_TEXTVIEW = 6;
- int WIDGET_TYPE_CUSTOM_EDITTEXT = 7;
- int WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW = 8;
- int WIDGET_TYPE_NOTIFICATION = 9;
- }
+ /** Widget type constants for logging. */
+ public static final class WidgetType {
+ // Sync these constants with textclassifier_enums.proto.
+ public static final int WIDGET_TYPE_UNKNOWN = 0;
+ public static final int WIDGET_TYPE_TEXTVIEW = 1;
+ public static final int WIDGET_TYPE_EDITTEXT = 2;
+ public static final int WIDGET_TYPE_UNSELECTABLE_TEXTVIEW = 3;
+ public static final int WIDGET_TYPE_WEBVIEW = 4;
+ public static final int WIDGET_TYPE_EDIT_WEBVIEW = 5;
+ public static final int WIDGET_TYPE_CUSTOM_TEXTVIEW = 6;
+ public static final int WIDGET_TYPE_CUSTOM_EDITTEXT = 7;
+ public static final int WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW = 8;
+ public static final int WIDGET_TYPE_NOTIFICATION = 9;
+
+ private WidgetType() {}
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzer.java b/java/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzer.java
index 64a9fe8..66db748 100644
--- a/java/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzer.java
+++ b/java/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzer.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,14 +19,11 @@
import android.content.Context;
import android.util.ArrayMap;
import android.view.textclassifier.TextClassifierEvent;
-
import com.android.textclassifier.TextClassificationConstants;
import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
import com.android.textclassifier.ulp.database.LanguageSignalInfo;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -42,77 +39,74 @@
*/
final class BasicLanguageProficiencyAnalyzer implements LanguageProficiencyAnalyzer {
- private static final long CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME =
- TimeUnit.HOURS.toMillis(6);
+ private static final long CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME =
+ TimeUnit.HOURS.toMillis(6);
- private final TextClassificationConstants mSettings;
- private final LanguageProfileDatabase mDatabase;
- private final SystemLanguagesProvider mSystemLanguagesProvider;
+ private final TextClassificationConstants settings;
+ private final LanguageProfileDatabase database;
+ private final SystemLanguagesProvider systemLanguagesProvider;
- private Map<String, Float> mCanUnderstandResultCache;
- private long mCanUnderstandResultCacheTime;
+ private Map<String, Float> canUnderstandResultCache;
+ private long canUnderstandResultCacheTime;
- BasicLanguageProficiencyAnalyzer(
- Context context,
- TextClassificationConstants settings,
- SystemLanguagesProvider systemLanguagesProvider) {
- this(settings, LanguageProfileDatabase.getInstance(context), systemLanguagesProvider);
+ BasicLanguageProficiencyAnalyzer(
+ Context context,
+ TextClassificationConstants settings,
+ SystemLanguagesProvider systemLanguagesProvider) {
+ this(settings, LanguageProfileDatabase.getInstance(context), systemLanguagesProvider);
+ }
+
+ @VisibleForTesting
+ BasicLanguageProficiencyAnalyzer(
+ TextClassificationConstants settings,
+ LanguageProfileDatabase languageProfileDatabase,
+ SystemLanguagesProvider systemLanguagesProvider) {
+ this.settings = Preconditions.checkNotNull(settings);
+ database = Preconditions.checkNotNull(languageProfileDatabase);
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ canUnderstandResultCache = new ArrayMap<>();
+ }
+
+ @Override
+ public synchronized float canUnderstand(String languageTag) {
+ if (canUnderstandResultCache.isEmpty()
+ || (System.currentTimeMillis() - canUnderstandResultCacheTime)
+ >= CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME) {
+ canUnderstandResultCache = createCanUnderstandResultCache();
+ canUnderstandResultCacheTime = System.currentTimeMillis();
}
+ return canUnderstandResultCache.getOrDefault(languageTag, 0f);
+ }
- @VisibleForTesting
- BasicLanguageProficiencyAnalyzer(
- TextClassificationConstants settings,
- LanguageProfileDatabase languageProfileDatabase,
- SystemLanguagesProvider systemLanguagesProvider) {
- mSettings = Preconditions.checkNotNull(settings);
- mDatabase = Preconditions.checkNotNull(languageProfileDatabase);
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
- mCanUnderstandResultCache = new ArrayMap<>();
+ @Override
+ public void onTextClassifierEvent(TextClassifierEvent event) {}
+
+ @Override
+ public boolean shouldShowTranslation(String languageCode) {
+ return canUnderstand(languageCode) >= settings.getTranslateActionThreshold();
+ }
+
+ private Map<String, Float> createCanUnderstandResultCache() {
+ Map<String, Float> result = new ArrayMap<>();
+ List<String> systemLanguageTags = systemLanguagesProvider.getSystemLanguageTags();
+ List<LanguageSignalInfo> languageSignalInfos =
+ database.languageInfoDao().getBySource(LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
+ // Applies system languages to bootstrap the model according to Zipf's Law.
+ // Zipf’s Law states that the ith most common language should be proportional to 1/i.
+ for (int i = 0; i < systemLanguageTags.size(); i++) {
+ String languageTag = systemLanguageTags.get(i);
+ result.put(
+ languageTag, settings.getLanguageProficiencyBootstrappingCount() / (float) (i + 1));
}
-
- @Override
- public synchronized float canUnderstand(String languageTag) {
- if (mCanUnderstandResultCache.isEmpty()
- || (System.currentTimeMillis() - mCanUnderstandResultCacheTime)
- >= CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME) {
- mCanUnderstandResultCache = createCanUnderstandResultCache();
- mCanUnderstandResultCacheTime = System.currentTimeMillis();
- }
- return mCanUnderstandResultCache.getOrDefault(languageTag, 0f);
+ // Adds message counts of different languages into the corresponding entry in the map
+ for (LanguageSignalInfo info : languageSignalInfos) {
+ String languageTag = info.getLanguageTag();
+ int count = info.getCount();
+ result.put(languageTag, result.getOrDefault(languageTag, 0f) + count);
}
-
- @Override
- public void onTextClassifierEvent(TextClassifierEvent event) {}
-
- @Override
- public boolean shouldShowTranslation(String languageCode) {
- return canUnderstand(languageCode) >= mSettings.getTranslateActionThreshold();
- }
-
- private Map<String, Float> createCanUnderstandResultCache() {
- Map<String, Float> result = new ArrayMap<>();
- List<String> systemLanguageTags = mSystemLanguagesProvider.getSystemLanguageTags();
- List<LanguageSignalInfo> languageSignalInfos =
- mDatabase
- .languageInfoDao()
- .getBySource(LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
- // Applies system languages to bootstrap the model according to Zipf's Law.
- // Zipf’s Law states that the ith most common language should be proportional to 1/i.
- for (int i = 0; i < systemLanguageTags.size(); i++) {
- String languageTag = systemLanguageTags.get(i);
- result.put(
- languageTag,
- mSettings.getLanguageProficiencyBootstrappingCount() / (float) (i + 1));
- }
- // Adds message counts of different languages into the corresponding entry in the map
- for (LanguageSignalInfo info : languageSignalInfos) {
- String languageTag = info.getLanguageTag();
- int count = info.getCount();
- result.put(languageTag, result.getOrDefault(languageTag, 0f) + count);
- }
- // Calculates confidence scores
- float max = Collections.max(result.values());
- result.forEach((languageTag, count) -> result.put(languageTag, count / max));
- return result;
- }
+ // Calculates confidence scores
+ float max = Collections.max(result.values());
+ result.forEach((languageTag, count) -> result.put(languageTag, count / max));
+ return result;
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzer.java b/java/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzer.java
index e33e45b..28d97e4 100644
--- a/java/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzer.java
+++ b/java/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzer.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,17 +18,13 @@
import android.content.Context;
import android.view.textclassifier.TextClassifierEvent;
-
import androidx.collection.ArrayMap;
-
import com.android.textclassifier.TextClassificationConstants;
import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
import com.android.textclassifier.ulp.database.LanguageSignalInfo;
import com.android.textclassifier.ulp.kmeans.KMeans;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -45,127 +41,126 @@
// STOPSHIP: Review the entire ULP package before shipping it.
final class KmeansLanguageProficiencyAnalyzer implements LanguageProficiencyAnalyzer {
- private static final long CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME =
- TimeUnit.HOURS.toMillis(6);
+ private static final long CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME =
+ TimeUnit.HOURS.toMillis(6);
- private final TextClassificationConstants mSettings;
- private final LanguageProfileDatabase mDatabase;
- private final KMeans mKmeans;
- private final SystemLanguagesProvider mSystemLanguagesProvider;
+ private final TextClassificationConstants settings;
+ private final LanguageProfileDatabase database;
+ private final KMeans kmeans;
+ private final SystemLanguagesProvider systemLanguagesProvider;
- private Map<String, Float> mCanUnderstandResultCache;
- private long mCanUnderstandResultCacheTime;
+ private Map<String, Float> canUnderstandResultCache;
+ private long canUnderstandResultCacheTime;
- KmeansLanguageProficiencyAnalyzer(
- Context context,
- TextClassificationConstants settings,
- SystemLanguagesProvider systemLanguagesProvider) {
- this(settings, LanguageProfileDatabase.getInstance(context), systemLanguagesProvider);
+ KmeansLanguageProficiencyAnalyzer(
+ Context context,
+ TextClassificationConstants settings,
+ SystemLanguagesProvider systemLanguagesProvider) {
+ this(settings, LanguageProfileDatabase.getInstance(context), systemLanguagesProvider);
+ }
+
+ @VisibleForTesting
+ KmeansLanguageProficiencyAnalyzer(
+ TextClassificationConstants settings,
+ LanguageProfileDatabase languageProfileDatabase,
+ SystemLanguagesProvider systemLanguagesProvider) {
+ this.settings = Preconditions.checkNotNull(settings);
+ database = Preconditions.checkNotNull(languageProfileDatabase);
+ kmeans = new KMeans();
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ canUnderstandResultCache = new ArrayMap<>();
+ }
+
+ @Override
+ public synchronized float canUnderstand(String languageTag) {
+ if (canUnderstandResultCache.isEmpty()
+ || (System.currentTimeMillis() - canUnderstandResultCacheTime)
+ >= CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME) {
+ canUnderstandResultCache = createCanUnderstandResultCache();
+ canUnderstandResultCacheTime = System.currentTimeMillis();
}
+ return canUnderstandResultCache.getOrDefault(languageTag, 0f);
+ }
- @VisibleForTesting
- KmeansLanguageProficiencyAnalyzer(
- TextClassificationConstants settings,
- LanguageProfileDatabase languageProfileDatabase,
- SystemLanguagesProvider systemLanguagesProvider) {
- mSettings = Preconditions.checkNotNull(settings);
- mDatabase = Preconditions.checkNotNull(languageProfileDatabase);
- mKmeans = new KMeans();
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
- mCanUnderstandResultCache = new ArrayMap<>();
+ private Map<String, Float> createCanUnderstandResultCache() {
+ Map<String, Float> result = new ArrayMap<>();
+ ArrayMap<String, Integer> languageCounts = new ArrayMap<>();
+ List<String> systemLanguageTags = systemLanguagesProvider.getSystemLanguageTags();
+ List<LanguageSignalInfo> languageSignalInfos =
+ database.languageInfoDao().getBySource(LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
+ // Applies system languages to bootstrap the model according to Zipf's Law.
+ // Zipf’s Law states that the ith most common language should be proportional to 1/i.
+ for (int i = 0; i < systemLanguageTags.size(); i++) {
+ String languageTag = systemLanguageTags.get(i);
+ languageCounts.put(
+ languageTag, settings.getLanguageProficiencyBootstrappingCount() / (i + 1));
}
-
- @Override
- public synchronized float canUnderstand(String languageTag) {
- if (mCanUnderstandResultCache.isEmpty()
- || (System.currentTimeMillis() - mCanUnderstandResultCacheTime)
- >= CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME) {
- mCanUnderstandResultCache = createCanUnderstandResultCache();
- mCanUnderstandResultCacheTime = System.currentTimeMillis();
- }
- return mCanUnderstandResultCache.getOrDefault(languageTag, 0f);
+ // Adds message counts of different languages into the corresponding entry in the map
+ for (LanguageSignalInfo info : languageSignalInfos) {
+ String languageTag = info.getLanguageTag();
+ int count = info.getCount();
+ languageCounts.put(languageTag, languageCounts.getOrDefault(languageTag, 0) + count);
}
-
- private Map<String, Float> createCanUnderstandResultCache() {
- Map<String, Float> result = new ArrayMap<>();
- ArrayMap<String, Integer> languageCounts = new ArrayMap<>();
- List<String> systemLanguageTags = mSystemLanguagesProvider.getSystemLanguageTags();
- List<LanguageSignalInfo> languageSignalInfos =
- mDatabase
- .languageInfoDao()
- .getBySource(LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
- // Applies system languages to bootstrap the model according to Zipf's Law.
- // Zipf’s Law states that the ith most common language should be proportional to 1/i.
- for (int i = 0; i < systemLanguageTags.size(); i++) {
- String languageTag = systemLanguageTags.get(i);
- languageCounts.put(
- languageTag, mSettings.getLanguageProficiencyBootstrappingCount() / (i + 1));
- }
- // Adds message counts of different languages into the corresponding entry in the map
- for (LanguageSignalInfo info : languageSignalInfos) {
- String languageTag = info.getLanguageTag();
- int count = info.getCount();
- languageCounts.put(languageTag, languageCounts.getOrDefault(languageTag, 0) + count);
- }
- // Calculates confidence scores
- if (languageCounts.size() == 1) {
- result.put(languageCounts.keyAt(0), 1f);
- return result;
- }
- if (languageCounts.size() == 2) {
- return evaluateTwoLanguageCounts(languageCounts);
- }
- // Applies K-Means to cluster data points
- int size = languageCounts.size();
- float[][] inputData = new float[size][1];
- for (int i = 0; i < size; i++) {
- inputData[i][0] = languageCounts.valueAt(i);
- }
- List<KMeans.Mean> means = mKmeans.predict(/* k= */ 2, inputData);
- List<Integer> countsInMaxCluster = getCountsWithinFarthestCluster(means);
- for (int i = 0; i < languageCounts.size(); i++) {
- float score = countsInMaxCluster.contains(languageCounts.valueAt(i)) ? 1f : 0f;
- result.put(languageCounts.keyAt(i), score);
- }
- return result;
+ // Calculates confidence scores
+ if (languageCounts.size() == 1) {
+ result.put(languageCounts.keyAt(0), 1f);
+ return result;
}
-
- @Override
- public void onTextClassifierEvent(TextClassifierEvent event) {}
-
- @Override
- public boolean shouldShowTranslation(String languageCode) {
- return canUnderstand(languageCode) >= mSettings.getTranslateActionThreshold();
+ if (languageCounts.size() == 2) {
+ return evaluateTwoLanguageCounts(languageCounts);
}
-
- private Map<String, Float> evaluateTwoLanguageCounts(ArrayMap<String, Integer> languageCounts) {
- Map<String, Float> result = new ArrayMap<>();
- int countOne = languageCounts.valueAt(0);
- String languageTagOne = languageCounts.keyAt(0);
- int countTwo = languageCounts.valueAt(1);
- String languageTagTwo = languageCounts.keyAt(1);
- if (countOne >= countTwo) {
- result.put(languageTagOne, 1f);
- result.put(languageTagTwo, countTwo / (float) countOne);
- } else {
- result.put(languageTagTwo, 1f);
- result.put(languageTagOne, countOne / (float) countTwo);
- }
- return result;
+ // Applies K-Means to cluster data points
+ int size = languageCounts.size();
+ float[][] inputData = new float[size][1];
+ for (int i = 0; i < size; i++) {
+ inputData[i][0] = languageCounts.valueAt(i);
}
-
- private List<Integer> getCountsWithinFarthestCluster(List<KMeans.Mean> means) {
- List<Integer> result = new ArrayList<>();
- KMeans.Mean farthestMean = means.get(0);
- for (int i = 1; i < means.size(); i++) {
- KMeans.Mean curMean = means.get(i);
- if (curMean.getCentroid()[0] > farthestMean.getCentroid()[0]) {
- farthestMean = curMean;
- }
- }
- for (float[] item : farthestMean.getItems()) {
- result.add((int) item[0]);
- }
- return result;
+ List<KMeans.Mean> means = kmeans.predict(/* k= */ 2, inputData);
+ List<Integer> countsInMaxCluster = getCountsWithinFarthestCluster(means);
+ for (int i = 0; i < languageCounts.size(); i++) {
+ float score = countsInMaxCluster.contains(languageCounts.valueAt(i)) ? 1f : 0f;
+ result.put(languageCounts.keyAt(i), score);
}
+ return result;
+ }
+
+ @Override
+ public void onTextClassifierEvent(TextClassifierEvent event) {}
+
+ @Override
+ public boolean shouldShowTranslation(String languageCode) {
+ return canUnderstand(languageCode) >= settings.getTranslateActionThreshold();
+ }
+
+ private static Map<String, Float> evaluateTwoLanguageCounts(
+ ArrayMap<String, Integer> languageCounts) {
+ Map<String, Float> result = new ArrayMap<>();
+ int countOne = languageCounts.valueAt(0);
+ String languageTagOne = languageCounts.keyAt(0);
+ int countTwo = languageCounts.valueAt(1);
+ String languageTagTwo = languageCounts.keyAt(1);
+ if (countOne >= countTwo) {
+ result.put(languageTagOne, 1f);
+ result.put(languageTagTwo, countTwo / (float) countOne);
+ } else {
+ result.put(languageTagTwo, 1f);
+ result.put(languageTagOne, countOne / (float) countTwo);
+ }
+ return result;
+ }
+
+ private static List<Integer> getCountsWithinFarthestCluster(List<KMeans.Mean> means) {
+ List<Integer> result = new ArrayList<>();
+ KMeans.Mean farthestMean = means.get(0);
+ for (int i = 1; i < means.size(); i++) {
+ KMeans.Mean curMean = means.get(i);
+ if (curMean.getCentroid()[0] > farthestMean.getCentroid()[0]) {
+ farthestMean = curMean;
+ }
+ }
+ for (float[] item : farthestMean.getItems()) {
+ result.add((int) item[0]);
+ }
+ return result;
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/LanguageProficiencyAnalyzer.java b/java/src/com/android/textclassifier/ulp/LanguageProficiencyAnalyzer.java
index 89ea236..c9e8695 100644
--- a/java/src/com/android/textclassifier/ulp/LanguageProficiencyAnalyzer.java
+++ b/java/src/com/android/textclassifier/ulp/LanguageProficiencyAnalyzer.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,14 +17,13 @@
package com.android.textclassifier.ulp;
import android.view.textclassifier.TextClassifierEvent;
-
import androidx.annotation.FloatRange;
interface LanguageProficiencyAnalyzer {
- @FloatRange(from = 0.0, to = 1.0)
- float canUnderstand(String languageCode);
+ @FloatRange(from = 0.0, to = 1.0)
+ float canUnderstand(String languageCode);
- void onTextClassifierEvent(TextClassifierEvent event);
+ void onTextClassifierEvent(TextClassifierEvent event);
- boolean shouldShowTranslation(String languageCode);
+ boolean shouldShowTranslation(String languageCode);
}
diff --git a/java/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluator.java b/java/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluator.java
index c8ed732..7c16418 100644
--- a/java/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluator.java
+++ b/java/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluator.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,106 +18,104 @@
import androidx.collection.ArrayMap;
import androidx.collection.ArraySet;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
import java.util.Set;
final class LanguageProficiencyEvaluator {
- private final SystemLanguagesProvider mSystemLanguagesProvider;
+ private final SystemLanguagesProvider systemLanguagesProvider;
- LanguageProficiencyEvaluator(SystemLanguagesProvider systemLanguagesProvider) {
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ LanguageProficiencyEvaluator(SystemLanguagesProvider systemLanguagesProvider) {
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ }
+
+ EvaluationResult evaluate(LanguageProficiencyAnalyzer analyzer, Set<String> languagesToEvaluate) {
+ Set<String> systemLanguageTags =
+ new ArraySet<>(systemLanguagesProvider.getSystemLanguageTags());
+ ArrayMap<String, Boolean> actual = new ArrayMap<>();
+ // We assume user can only speak the languages that are set as system languages.
+ for (String languageToEvaluate : languagesToEvaluate) {
+ actual.put(languageToEvaluate, systemLanguageTags.contains(languageToEvaluate));
+ }
+ return evaluateWithActual(analyzer, actual);
+ }
+
+ private static EvaluationResult evaluateWithActual(
+ LanguageProficiencyAnalyzer analyzer, ArrayMap<String, Boolean> actual) {
+ ArrayMap<String, Boolean> predict = new ArrayMap<>();
+ for (int i = 0; i < actual.size(); i++) {
+ String languageTag = actual.keyAt(i);
+ predict.put(languageTag, analyzer.canUnderstand(languageTag) >= 0.5f);
+ }
+ return EvaluationResult.create(actual, predict);
+ }
+
+ static final class EvaluationResult {
+ final int truePositive;
+ final int trueNegative;
+ final int falsePositive;
+ final int falseNegative;
+
+ private EvaluationResult(
+ int truePositive, int trueNegative, int falsePositive, int falseNegative) {
+ this.truePositive = truePositive;
+ this.trueNegative = trueNegative;
+ this.falsePositive = falsePositive;
+ this.falseNegative = falseNegative;
}
- EvaluationResult evaluate(
- LanguageProficiencyAnalyzer analyzer, Set<String> languagesToEvaluate) {
- Set<String> systemLanguageTags =
- new ArraySet<>(mSystemLanguagesProvider.getSystemLanguageTags());
- ArrayMap<String, Boolean> actual = new ArrayMap<>();
- // We assume user can only speak the languages that are set as system languages.
- for (String languageToEvaluate : languagesToEvaluate) {
- actual.put(languageToEvaluate, systemLanguageTags.contains(languageToEvaluate));
- }
- return evaluateWithActual(analyzer, actual);
+ float computePrecisionOfPositiveClass() {
+ float divisor = truePositive + falsePositive;
+ return divisor != 0 ? truePositive / divisor : 1f;
}
- private EvaluationResult evaluateWithActual(
- LanguageProficiencyAnalyzer analyzer, ArrayMap<String, Boolean> actual) {
- ArrayMap<String, Boolean> predict = new ArrayMap<>();
- for (int i = 0; i < actual.size(); i++) {
- String languageTag = actual.keyAt(i);
- predict.put(languageTag, analyzer.canUnderstand(languageTag) >= 0.5f);
- }
- return EvaluationResult.create(actual, predict);
+ float computePrecisionOfNegativeClass() {
+ float divisor = trueNegative + falseNegative;
+ return divisor != 0 ? trueNegative / divisor : 1f;
}
- static final class EvaluationResult {
- final int truePositive;
- final int trueNegative;
- final int falsePositive;
- final int falseNegative;
-
- private EvaluationResult(
- int truePositive, int trueNegative, int falsePositive, int falseNegative) {
- this.truePositive = truePositive;
- this.trueNegative = trueNegative;
- this.falsePositive = falsePositive;
- this.falseNegative = falseNegative;
- }
-
- float computePrecisionOfPositiveClass() {
- float divisor = truePositive + falsePositive;
- return divisor != 0 ? truePositive / divisor : 1f;
- }
-
- float computePrecisionOfNegativeClass() {
- float divisor = trueNegative + falseNegative;
- return divisor != 0 ? trueNegative / divisor : 1f;
- }
-
- float computeRecallOfPositiveClass() {
- float divisor = truePositive + falseNegative;
- return divisor != 0 ? truePositive / divisor : 1f;
- }
-
- float computeRecallOfNegativeClass() {
- float divisor = trueNegative + falsePositive;
- return divisor != 0 ? trueNegative / divisor : 1f;
- }
-
- float computeF1ScoreOfPositiveClass() {
- return 2 * truePositive / (float) (2 * truePositive + falsePositive + falseNegative);
- }
-
- float computeF1ScoreOfNegativeClass() {
- return 2 * trueNegative / (float) (2 * trueNegative + falsePositive + falseNegative);
- }
-
- static EvaluationResult create(
- ArrayMap<String, Boolean> actual, ArrayMap<String, Boolean> predict) {
- int truePositive = 0;
- int trueNegative = 0;
- int falsePositive = 0;
- int falseNegative = 0;
- for (int i = 0; i < actual.size(); i++) {
- String languageTag = actual.keyAt(i);
- boolean actualLabel = actual.valueAt(i);
- boolean predictLabel = predict.get(languageTag);
- if (predictLabel) {
- if (actualLabel == predictLabel) {
- truePositive += 1;
- } else {
- falsePositive += 1;
- }
- } else {
- if (actualLabel == predictLabel) {
- trueNegative += 1;
- } else {
- falseNegative += 1;
- }
- }
- }
- return new EvaluationResult(truePositive, trueNegative, falsePositive, falseNegative);
- }
+ float computeRecallOfPositiveClass() {
+ float divisor = truePositive + falseNegative;
+ return divisor != 0 ? truePositive / divisor : 1f;
}
+
+ float computeRecallOfNegativeClass() {
+ float divisor = trueNegative + falsePositive;
+ return divisor != 0 ? trueNegative / divisor : 1f;
+ }
+
+ float computeF1ScoreOfPositiveClass() {
+ return 2 * truePositive / (float) (2 * truePositive + falsePositive + falseNegative);
+ }
+
+ float computeF1ScoreOfNegativeClass() {
+ return 2 * trueNegative / (float) (2 * trueNegative + falsePositive + falseNegative);
+ }
+
+ static EvaluationResult create(
+ ArrayMap<String, Boolean> actual, ArrayMap<String, Boolean> predict) {
+ int truePositive = 0;
+ int trueNegative = 0;
+ int falsePositive = 0;
+ int falseNegative = 0;
+ for (int i = 0; i < actual.size(); i++) {
+ String languageTag = actual.keyAt(i);
+ boolean actualLabel = actual.valueAt(i);
+ boolean predictLabel = predict.get(languageTag);
+ if (predictLabel) {
+ if (actualLabel == predictLabel) {
+ truePositive += 1;
+ } else {
+ falsePositive += 1;
+ }
+ } else {
+ if (actualLabel == predictLabel) {
+ trueNegative += 1;
+ } else {
+ falseNegative += 1;
+ }
+ }
+ }
+ return new EvaluationResult(truePositive, trueNegative, falsePositive, falseNegative);
+ }
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/LanguageProfileAnalyzer.java b/java/src/com/android/textclassifier/ulp/LanguageProfileAnalyzer.java
index ac220f3..4c436dc 100644
--- a/java/src/com/android/textclassifier/ulp/LanguageProfileAnalyzer.java
+++ b/java/src/com/android/textclassifier/ulp/LanguageProfileAnalyzer.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,21 +19,17 @@
import android.content.Context;
import android.util.ArrayMap;
import android.view.textclassifier.TextClassifierEvent;
-
import androidx.annotation.FloatRange;
-import androidx.annotation.NonNull;
-
import com.android.textclassifier.Entity;
import com.android.textclassifier.TcLog;
import com.android.textclassifier.TextClassificationConstants;
import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
import com.android.textclassifier.ulp.database.LanguageSignalInfo;
import com.android.textclassifier.utils.IndentingPrintWriter;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.MoreExecutors;
-
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -51,188 +47,177 @@
* blocking operations and should be called on the worker thread.
*/
public class LanguageProfileAnalyzer {
- private final Context mContext;
- private final TextClassificationConstants mTextClassificationConstants;
- private final LanguageProfileDatabase mLanguageProfileDatabase;
- private final LanguageProficiencyAnalyzer mProficiencyAnalyzer;
- private final LocationSignalProvider mLocationSignalProvider;
- private final SystemLanguagesProvider mSystemLanguagesProvider;
+ private final Context context;
+ private final TextClassificationConstants textClassificationConstants;
+ private final LanguageProfileDatabase languageProfileDatabase;
+ private final LanguageProficiencyAnalyzer proficiencyAnalyzer;
+ private final LocationSignalProvider locationSignalProvider;
+ private final SystemLanguagesProvider systemLanguagesProvider;
- @VisibleForTesting
- LanguageProfileAnalyzer(
- Context context,
- TextClassificationConstants textClassificationConstants,
- LanguageProfileDatabase database,
- LanguageProficiencyAnalyzer languageProficiencyAnalyzer,
- LocationSignalProvider locationSignalProvider,
- SystemLanguagesProvider systemLanguagesProvider) {
- mContext = context;
- mTextClassificationConstants = textClassificationConstants;
- mLanguageProfileDatabase = Preconditions.checkNotNull(database);
- mProficiencyAnalyzer = Preconditions.checkNotNull(languageProficiencyAnalyzer);
- mLocationSignalProvider = Preconditions.checkNotNull(locationSignalProvider);
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ @VisibleForTesting
+ LanguageProfileAnalyzer(
+ Context context,
+ TextClassificationConstants textClassificationConstants,
+ LanguageProfileDatabase database,
+ LanguageProficiencyAnalyzer languageProficiencyAnalyzer,
+ LocationSignalProvider locationSignalProvider,
+ SystemLanguagesProvider systemLanguagesProvider) {
+ this.context = context;
+ this.textClassificationConstants = textClassificationConstants;
+ languageProfileDatabase = Preconditions.checkNotNull(database);
+ proficiencyAnalyzer = Preconditions.checkNotNull(languageProficiencyAnalyzer);
+ this.locationSignalProvider = Preconditions.checkNotNull(locationSignalProvider);
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ }
+
+ /** Creates an instance of {@link LanguageProfileAnalyzer}. */
+ public static LanguageProfileAnalyzer create(
+ Context context, TextClassificationConstants textClassificationConstants) {
+ SystemLanguagesProvider systemLanguagesProvider = new SystemLanguagesProvider();
+ LocationSignalProvider locationSignalProvider = new LocationSignalProvider(context);
+ return new LanguageProfileAnalyzer(
+ context,
+ textClassificationConstants,
+ LanguageProfileDatabase.getInstance(context),
+ new ReinforcementLanguageProficiencyAnalyzer(context, systemLanguagesProvider),
+ locationSignalProvider,
+ systemLanguagesProvider);
+ }
+
+ /**
+ * Returns the confidence score for which the user understands the given language. The result is
+ * recalculated every constant time.
+ *
+ * <p>The score ranges from 0 to 1. 1 indicates the language is very familiar to the user and vice
+ * versa.
+ */
+ @FloatRange(from = 0.0, to = 1.0)
+ public float canUnderstand(String languageTag) {
+ return proficiencyAnalyzer.canUnderstand(languageTag);
+ }
+
+ /** Decides whether we should show translation for that language or no. */
+ public boolean shouldShowTranslation(String languageTag) {
+ return proficiencyAnalyzer.shouldShowTranslation(languageTag);
+ }
+
+ /** Performs actions defined for specific TextClassification events. */
+ public void onTextClassifierEven(TextClassifierEvent event) {
+ proficiencyAnalyzer.onTextClassifierEvent(event);
+ }
+
+ /**
+ * Returns a list of languages that appear in the specified source, the list is sorted by the
+ * frequency descendingly. The confidence score represents how frequent of the language is,
+ * compared to the most frequent language.
+ */
+ public List<Entity> getFrequentLanguages(@LanguageSignalInfo.Source int source) {
+ List<LanguageSignalInfo> languageSignalInfos =
+ languageProfileDatabase.languageInfoDao().getBySource(source);
+ int bootstrappingCount = textClassificationConstants.getFrequentLanguagesBootstrappingCount();
+ ArrayMap<String, Integer> languageCountMap = new ArrayMap<>();
+ systemLanguagesProvider
+ .getSystemLanguageTags()
+ .forEach(lang -> languageCountMap.put(lang, bootstrappingCount));
+ String languageTagFromLocation = locationSignalProvider.detectLanguageTag();
+ if (languageTagFromLocation != null) {
+ languageCountMap.put(
+ languageTagFromLocation,
+ languageCountMap.getOrDefault(languageTagFromLocation, 0) + bootstrappingCount);
}
-
- /** Creates an instance of {@link LanguageProfileAnalyzer}. */
- public static LanguageProfileAnalyzer create(
- Context context, TextClassificationConstants textClassificationConstants) {
- SystemLanguagesProvider systemLanguagesProvider = new SystemLanguagesProvider();
- LocationSignalProvider locationSignalProvider = new LocationSignalProvider(context);
- return new LanguageProfileAnalyzer(
- context,
- textClassificationConstants,
- LanguageProfileDatabase.getInstance(context),
- new ReinforcementLanguageProficiencyAnalyzer(context, systemLanguagesProvider),
- locationSignalProvider,
- systemLanguagesProvider);
+ for (LanguageSignalInfo languageSignalInfo : languageSignalInfos) {
+ String lang = languageSignalInfo.getLanguageTag();
+ languageCountMap.put(
+ lang, languageSignalInfo.getCount() + languageCountMap.getOrDefault(lang, 0));
}
-
- /**
- * Returns the confidence score for which the user understands the given language. The result is
- * recalculated every constant time.
- *
- * <p>The score ranges from 0 to 1. 1 indicates the language is very familiar to the user and
- * vice versa.
- */
- @FloatRange(from = 0.0, to = 1.0)
- public float canUnderstand(String languageTag) {
- return mProficiencyAnalyzer.canUnderstand(languageTag);
+ int max = Collections.max(languageCountMap.values());
+ if (max == 0) {
+ return ImmutableList.of();
}
-
- /** Decides whether we should show translation for that language or no. */
- public boolean shouldShowTranslation(String languageTag) {
- return mProficiencyAnalyzer.shouldShowTranslation(languageTag);
+ List<Entity> frequentLanguages = new ArrayList<>();
+ for (int i = 0; i < languageCountMap.size(); i++) {
+ String lang = languageCountMap.keyAt(i);
+ float score = languageCountMap.valueAt(i) / (float) max;
+ frequentLanguages.add(new Entity(lang, score));
}
+ Collections.sort(frequentLanguages);
+ return ImmutableList.copyOf(frequentLanguages);
+ }
- /** Performs actions defined for specific TextClassification events. */
- public void onTextClassifierEven(TextClassifierEvent event) {
- mProficiencyAnalyzer.onTextClassifierEvent(event);
+ /** Dumps the data on the screen when called. */
+ public void dump(IndentingPrintWriter printWriter) {
+ printWriter.println("LanguageProfileAnalyzer:");
+ printWriter.increaseIndent();
+ printWriter.printPair(
+ "System languages", String.join(",", systemLanguagesProvider.getSystemLanguageTags()));
+ printWriter.printPair(
+ "Language code deduced from location", locationSignalProvider.detectLanguageTag());
+
+ ExecutorService executorService =
+ MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor());
+ try {
+ executorService
+ .submit(
+ () -> {
+ printWriter.println("Languages that user has seen in selections:");
+ dumpFrequentLanguages(printWriter, LanguageSignalInfo.CLASSIFY_TEXT);
+
+ printWriter.println("Languages that user has seen in message notifications:");
+ dumpFrequentLanguages(printWriter, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
+
+ dumpEvaluationReport(printWriter);
+ })
+ .get();
+ } catch (ExecutionException | InterruptedException e) {
+ TcLog.e(TcLog.TAG, "Dumping interrupted: ", e);
}
+ printWriter.decreaseIndent();
+ }
- /**
- * Returns a list of languages that appear in the specified source, the list is sorted by the
- * frequency descendingly. The confidence score represents how frequent of the language is,
- * compared to the most frequent language.
- */
- @NonNull
- public List<Entity> getFrequentLanguages(@LanguageSignalInfo.Source int source) {
- List<LanguageSignalInfo> languageSignalInfos =
- mLanguageProfileDatabase.languageInfoDao().getBySource(source);
- int bootstrappingCount =
- mTextClassificationConstants.getFrequentLanguagesBootstrappingCount();
- ArrayMap<String, Integer> languageCountMap = new ArrayMap<>();
- mSystemLanguagesProvider
- .getSystemLanguageTags()
- .forEach(lang -> languageCountMap.put(lang, bootstrappingCount));
- String languageTagFromLocation = mLocationSignalProvider.detectLanguageTag();
- if (languageTagFromLocation != null) {
- languageCountMap.put(
- languageTagFromLocation,
- languageCountMap.getOrDefault(languageTagFromLocation, 0) + bootstrappingCount);
- }
- for (LanguageSignalInfo languageSignalInfo : languageSignalInfos) {
- String lang = languageSignalInfo.getLanguageTag();
- languageCountMap.put(
- lang, languageSignalInfo.getCount() + languageCountMap.getOrDefault(lang, 0));
- }
- int max = Collections.max(languageCountMap.values());
- if (max == 0) {
- return Collections.emptyList();
- }
- List<Entity> frequentLanguages = new ArrayList<>();
- for (int i = 0; i < languageCountMap.size(); i++) {
- String lang = languageCountMap.keyAt(i);
- float score = languageCountMap.valueAt(i) / (float) max;
- frequentLanguages.add(new Entity(lang, score));
- }
- Collections.sort(frequentLanguages);
- return frequentLanguages;
+ private void dumpEvaluationReport(IndentingPrintWriter printWriter) {
+ List<String> systemLanguageTags = systemLanguagesProvider.getSystemLanguageTags();
+ if (systemLanguageTags.size() <= 1) {
+ printWriter.println("Skipped evaluation as there are less than two system languages.");
+ return;
}
-
- /** Dumps the data on the screen when called. */
- public void dump(IndentingPrintWriter printWriter) {
- printWriter.println("LanguageProfileAnalyzer:");
- printWriter.increaseIndent();
- printWriter.printPair(
- "System languages",
- String.join(",", mSystemLanguagesProvider.getSystemLanguageTags()));
- printWriter.printPair(
- "Language code deduced from location", mLocationSignalProvider.detectLanguageTag());
-
- ExecutorService executorService =
- MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor());
- try {
- executorService
- .submit(
- () -> {
- printWriter.println("Languages that user has seen in selections:");
- dumpFrequentLanguages(
- printWriter, LanguageSignalInfo.CLASSIFY_TEXT);
-
- printWriter.println(
- "Languages that user has seen in message notifications:");
- dumpFrequentLanguages(
- printWriter,
- LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
-
- dumpEvaluationReport(printWriter);
- })
- .get();
- } catch (ExecutionException | InterruptedException e) {
- TcLog.e(TcLog.TAG, "Dumping interrupted: ", e);
- }
- printWriter.decreaseIndent();
+ Set<String> languagesToEvaluate =
+ languageProfileDatabase.languageInfoDao().getAll().stream()
+ .map(LanguageSignalInfo::getLanguageTag)
+ .collect(Collectors.toSet());
+ languagesToEvaluate.addAll(systemLanguageTags);
+ LanguageProficiencyEvaluator evaluator =
+ new LanguageProficiencyEvaluator(systemLanguagesProvider);
+ LanguageProficiencyAnalyzer[] analyzers =
+ new LanguageProficiencyAnalyzer[] {
+ new BasicLanguageProficiencyAnalyzer(
+ context, textClassificationConstants, systemLanguagesProvider),
+ new KmeansLanguageProficiencyAnalyzer(
+ context, textClassificationConstants, systemLanguagesProvider),
+ proficiencyAnalyzer
+ };
+ for (LanguageProficiencyAnalyzer analyzer : analyzers) {
+ LanguageProficiencyEvaluator.EvaluationResult result =
+ evaluator.evaluate(analyzer, languagesToEvaluate);
+ printWriter.println("Evaluation result of " + analyzer.getClass().getSimpleName());
+ printWriter.increaseIndent();
+ printWriter.printPair(
+ "Precision of positive class", result.computePrecisionOfPositiveClass());
+ printWriter.printPair(
+ "Precision of negative class", result.computePrecisionOfNegativeClass());
+ printWriter.printPair("Recall of positive class", result.computeRecallOfPositiveClass());
+ printWriter.printPair("Recall of negative class", result.computeRecallOfNegativeClass());
+ printWriter.printPair("F1 score of positive class", result.computeF1ScoreOfPositiveClass());
+ printWriter.printPair("F1 score of negative class", result.computeF1ScoreOfNegativeClass());
+ printWriter.decreaseIndent();
}
+ }
- private void dumpEvaluationReport(IndentingPrintWriter printWriter) {
- List<String> systemLanguageTags = mSystemLanguagesProvider.getSystemLanguageTags();
- if (systemLanguageTags.size() <= 1) {
- printWriter.println("Skipped evaluation as there are less than two system languages.");
- return;
- }
- Set<String> languagesToEvaluate =
- mLanguageProfileDatabase.languageInfoDao().getAll().stream()
- .map(LanguageSignalInfo::getLanguageTag)
- .collect(Collectors.toSet());
- languagesToEvaluate.addAll(systemLanguageTags);
- LanguageProficiencyEvaluator evaluator =
- new LanguageProficiencyEvaluator(mSystemLanguagesProvider);
- LanguageProficiencyAnalyzer[] analyzers =
- new LanguageProficiencyAnalyzer[] {
- new BasicLanguageProficiencyAnalyzer(
- mContext, mTextClassificationConstants, mSystemLanguagesProvider),
- new KmeansLanguageProficiencyAnalyzer(
- mContext, mTextClassificationConstants, mSystemLanguagesProvider),
- mProficiencyAnalyzer
- };
- for (LanguageProficiencyAnalyzer analyzer : analyzers) {
- LanguageProficiencyEvaluator.EvaluationResult result =
- evaluator.evaluate(analyzer, languagesToEvaluate);
- printWriter.println("Evaluation result of " + analyzer.getClass().getSimpleName());
- printWriter.increaseIndent();
- printWriter.printPair(
- "Precision of positive class", result.computePrecisionOfPositiveClass());
- printWriter.printPair(
- "Precision of negative class", result.computePrecisionOfNegativeClass());
- printWriter.printPair(
- "Recall of positive class", result.computeRecallOfPositiveClass());
- printWriter.printPair(
- "Recall of negative class", result.computeRecallOfNegativeClass());
- printWriter.printPair(
- "F1 score of positive class", result.computeF1ScoreOfPositiveClass());
- printWriter.printPair(
- "F1 score of negative class", result.computeF1ScoreOfNegativeClass());
- printWriter.decreaseIndent();
- }
+ private void dumpFrequentLanguages(
+ IndentingPrintWriter printWriter, @LanguageSignalInfo.Source int source) {
+ printWriter.increaseIndent();
+ for (Entity frequentLanguage : getFrequentLanguages(source)) {
+ printWriter.println(frequentLanguage.toString());
}
-
- private void dumpFrequentLanguages(
- IndentingPrintWriter printWriter, @LanguageSignalInfo.Source int source) {
- printWriter.increaseIndent();
- for (Entity frequentLanguage : getFrequentLanguages(source)) {
- printWriter.println(frequentLanguage.toString());
- }
- printWriter.decreaseIndent();
- }
+ printWriter.decreaseIndent();
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/LanguageProfileUpdater.java b/java/src/com/android/textclassifier/ulp/LanguageProfileUpdater.java
index 55ba729..56f878f 100644
--- a/java/src/com/android/textclassifier/ulp/LanguageProfileUpdater.java
+++ b/java/src/com/android/textclassifier/ulp/LanguageProfileUpdater.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,158 +19,154 @@
import android.content.Context;
import android.util.LruCache;
import android.view.textclassifier.ConversationActions;
-
-import androidx.annotation.VisibleForTesting;
-
import com.android.textclassifier.TcLog;
import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
import com.android.textclassifier.ulp.database.LanguageSignalInfo;
import com.android.textclassifier.utils.IndentingPrintWriter;
-
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
-
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.time.Instant;
-import java.time.OffsetDateTime;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.function.Function;
-
import javax.annotation.Nullable;
/** Class implementing functions which builds and updates user language profile. */
public class LanguageProfileUpdater {
- private static final String TAG = "LanguageProfileUpdater";
- private static final int MAX_CACHE_SIZE = 20;
- private static final String DEFAULT_NOTIFICATION_KEY = "DEFAULT_KEY";
+ private static final String TAG = "LanguageProfileUpdater";
+ private static final int MAX_CACHE_SIZE = 20;
+ private static final String DEFAULT_NOTIFICATION_KEY = "DEFAULT_KEY";
- static final String NOTIFICATION_KEY = "notificationKey";
+ static final String NOTIFICATION_KEY = "notificationKey";
- private final LanguageProfileDatabase mLanguageProfileDatabase;
- private final ListeningExecutorService mExecutorService;
- private final LruCache<String, Long> mUpdatedNotifications = new LruCache<>(MAX_CACHE_SIZE);
+ private final LanguageProfileDatabase languageProfileDatabase;
+ private final ListeningExecutorService executorService;
+ private final LruCache<String, Long> updatedNotifications = new LruCache<>(MAX_CACHE_SIZE);
- public LanguageProfileUpdater(Context context, ListeningExecutorService executorService) {
- mLanguageProfileDatabase = LanguageProfileDatabase.getInstance(context);
- mExecutorService = executorService;
- }
+ public LanguageProfileUpdater(Context context, ListeningExecutorService executorService) {
+ languageProfileDatabase = LanguageProfileDatabase.getInstance(context);
+ this.executorService = executorService;
+ }
- @VisibleForTesting
- LanguageProfileUpdater(
- ListeningExecutorService executorService, LanguageProfileDatabase database) {
- mLanguageProfileDatabase = database;
- mExecutorService = executorService;
- }
+ @VisibleForTesting
+ LanguageProfileUpdater(
+ ListeningExecutorService executorService, LanguageProfileDatabase database) {
+ languageProfileDatabase = database;
+ this.executorService = executorService;
+ }
- /** Updates counts of languages found in suggestConversationActions. */
- public ListenableFuture<Void> updateFromConversationActionsAsync(
- ConversationActions.Request request,
- Function<CharSequence, List<String>> languageDetector) {
- return runAsync(
- () -> {
- ConversationActions.Message msg = getMessageFromRequest(request);
- if (msg == null) {
- return null;
- }
- List<String> languageTags = languageDetector.apply(msg.getText().toString());
- String notificationKey =
- request.getExtras()
- .getString(NOTIFICATION_KEY, DEFAULT_NOTIFICATION_KEY);
- Long messageReferenceTime = getMessageReferenceTime(msg);
- if (isNewMessage(notificationKey, messageReferenceTime)) {
- for (String tag : languageTags) {
- increaseSignalCountInDatabase(
- tag, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 1);
- }
- }
- return null;
- });
- }
-
- /** Updates counts of languages found in classifyText. */
- public ListenableFuture<Void> updateFromClassifyTextAsync(List<String> detectedLanguageTags) {
- return runAsync(
- () -> {
- for (String languageTag : detectedLanguageTags) {
- increaseSignalCountInDatabase(
- languageTag, LanguageSignalInfo.CLASSIFY_TEXT, /* increment= */ 1);
- }
- return null;
- });
- }
-
- /** Runs the specified callable asynchronously and prints the stack trace if it failed. */
- private <T> ListenableFuture<T> runAsync(Callable<T> callable) {
- ListenableFuture<T> future = mExecutorService.submit(callable);
- Futures.addCallback(
- future,
- new FutureCallback<T>() {
- @Override
- public void onSuccess(T result) {}
-
- @Override
- public void onFailure(Throwable t) {
- TcLog.e(TAG, "runAsync", t);
- }
- },
- MoreExecutors.directExecutor());
- return future;
- }
-
- private void increaseSignalCountInDatabase(
- String languageTag, @LanguageSignalInfo.Source int sourceType, int increment) {
- mLanguageProfileDatabase
- .languageInfoDao()
- .increaseSignalCount(languageTag, sourceType, increment);
- }
-
- @Nullable
- private ConversationActions.Message getMessageFromRequest(ConversationActions.Request request) {
- int size = request.getConversation().size();
- if (size == 0) {
+ /** Updates counts of languages found in suggestConversationActions. */
+ @CanIgnoreReturnValue
+ public ListenableFuture<Void> updateFromConversationActionsAsync(
+ ConversationActions.Request request, Function<CharSequence, List<String>> languageDetector) {
+ return runAsync(
+ () -> {
+ ConversationActions.Message msg = getMessageFromRequest(request);
+ if (msg == null) {
return null;
- }
- return request.getConversation().get(size - 1);
- }
+ }
+ List<String> languageTags = languageDetector.apply(msg.getText().toString());
+ String notificationKey =
+ request.getExtras().getString(NOTIFICATION_KEY, DEFAULT_NOTIFICATION_KEY);
+ Long messageReferenceTime = getMessageReferenceTime(msg);
+ if (isNewMessage(notificationKey, messageReferenceTime)) {
+ for (String tag : languageTags) {
+ increaseSignalCountInDatabase(
+ tag, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 1);
+ }
+ }
+ return null;
+ });
+ }
- private boolean isNewMessage(String notificationKey, Long sendTime) {
- Long oldTime = mUpdatedNotifications.get(notificationKey);
+ /** Updates counts of languages found in classifyText. */
+ @CanIgnoreReturnValue
+ public ListenableFuture<Void> updateFromClassifyTextAsync(List<String> detectedLanguageTags) {
+ return runAsync(
+ () -> {
+ for (String languageTag : detectedLanguageTags) {
+ increaseSignalCountInDatabase(
+ languageTag, LanguageSignalInfo.CLASSIFY_TEXT, /* increment= */ 1);
+ }
+ return null;
+ });
+ }
- if (oldTime == null || sendTime > oldTime) {
- mUpdatedNotifications.put(notificationKey, sendTime);
- return true;
- }
- return false;
- }
+ /** Runs the specified callable asynchronously and prints the stack trace if it failed. */
+ private <T> ListenableFuture<T> runAsync(Callable<T> callable) {
+ ListenableFuture<T> future = executorService.submit(callable);
+ Futures.addCallback(
+ future,
+ new FutureCallback<T>() {
+ @Override
+ public void onSuccess(T result) {}
- private long getMessageReferenceTime(ConversationActions.Message msg) {
- return msg.getReferenceTime() == null
- ? OffsetDateTime.now().toInstant().toEpochMilli()
- : msg.getReferenceTime().toInstant().toEpochMilli();
- }
+ @Override
+ public void onFailure(Throwable t) {
+ TcLog.e(TAG, "runAsync", t);
+ }
+ },
+ MoreExecutors.directExecutor());
+ return future;
+ }
- /** Dumps the data on the screen when called. */
- public void dump(IndentingPrintWriter printWriter) {
- printWriter.println("LanguageProfileUpdater:");
- printWriter.increaseIndent();
- printWriter.println("Cache for notifications status:");
- printWriter.increaseIndent();
- for (Map.Entry<String, Long> entry : mUpdatedNotifications.snapshot().entrySet()) {
- long timestamp = entry.getValue();
- printWriter.println(
- "Notification key: "
- + entry.getKey()
- + " time: "
- + timestamp
- + " ("
- + Instant.ofEpochMilli(timestamp).toString()
- + ")");
- }
- printWriter.decreaseIndent();
- printWriter.decreaseIndent();
+ private void increaseSignalCountInDatabase(
+ String languageTag, @LanguageSignalInfo.Source int sourceType, int increment) {
+ languageProfileDatabase
+ .languageInfoDao()
+ .increaseSignalCount(languageTag, sourceType, increment);
+ }
+
+ @Nullable
+ private static ConversationActions.Message getMessageFromRequest(
+ ConversationActions.Request request) {
+ int size = request.getConversation().size();
+ if (size == 0) {
+ return null;
}
+ return request.getConversation().get(size - 1);
+ }
+
+ private boolean isNewMessage(String notificationKey, Long sendTime) {
+ Long oldTime = updatedNotifications.get(notificationKey);
+
+ if (oldTime == null || sendTime > oldTime) {
+ updatedNotifications.put(notificationKey, sendTime);
+ return true;
+ }
+ return false;
+ }
+
+ private static long getMessageReferenceTime(ConversationActions.Message msg) {
+ return msg.getReferenceTime() == null
+ ? Instant.now().toEpochMilli()
+ : msg.getReferenceTime().toInstant().toEpochMilli();
+ }
+
+ /** Dumps the data on the screen when called. */
+ public void dump(IndentingPrintWriter printWriter) {
+ printWriter.println("LanguageProfileUpdater:");
+ printWriter.increaseIndent();
+ printWriter.println("Cache for notifications status:");
+ printWriter.increaseIndent();
+ for (Map.Entry<String, Long> entry : updatedNotifications.snapshot().entrySet()) {
+ long timestamp = entry.getValue();
+ printWriter.println(
+ "Notification key: "
+ + entry.getKey()
+ + " time: "
+ + timestamp
+ + " ("
+ + Instant.ofEpochMilli(timestamp).toString()
+ + ")");
+ }
+ printWriter.decreaseIndent();
+ printWriter.decreaseIndent();
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/LocationSignalProvider.java b/java/src/com/android/textclassifier/ulp/LocationSignalProvider.java
index c3ad77a..8b678b9 100644
--- a/java/src/com/android/textclassifier/ulp/LocationSignalProvider.java
+++ b/java/src/com/android/textclassifier/ulp/LocationSignalProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,103 +23,99 @@
import android.location.Location;
import android.location.LocationManager;
import android.telephony.TelephonyManager;
-
-import androidx.annotation.Nullable;
-
import com.android.textclassifier.TcLog;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-
import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
final class LocationSignalProvider {
- private static final String TAG = "LocationSignalProvider";
+ private static final String TAG = "LocationSignalProvider";
- private static final long EXPIRATION_TIME = TimeUnit.DAYS.toMillis(1);
- private final LocationManager mLocationManager;
- private final TelephonyManager mTelephonyManager;
- private final Geocoder mGeocoder;
- private final Object mLock = new Object();
+ private static final long EXPIRATION_TIME = TimeUnit.DAYS.toMillis(1);
+ private final LocationManager locationManager;
+ private final TelephonyManager telephonyManager;
+ private final Geocoder geocoder;
+ private final Object lock = new Object();
- private String mCachedLanguageTag = null;
- private long mLastModifiedTime;
+ private String cachedLanguageTag = null;
+ private long lastModifiedTime;
- LocationSignalProvider(Context context) {
- this(
- context.getSystemService(LocationManager.class),
- context.getSystemService(TelephonyManager.class),
- new Geocoder(context));
+ LocationSignalProvider(Context context) {
+ this(
+ context.getSystemService(LocationManager.class),
+ context.getSystemService(TelephonyManager.class),
+ new Geocoder(context));
+ }
+
+ @VisibleForTesting
+ LocationSignalProvider(
+ LocationManager locationManager, TelephonyManager telephonyManager, Geocoder geocoder) {
+ this.locationManager = Preconditions.checkNotNull(locationManager);
+ this.telephonyManager = Preconditions.checkNotNull(telephonyManager);
+ this.geocoder = Preconditions.checkNotNull(geocoder);
+ }
+
+ /**
+ * Deduces the language by using user's current location as a signal and returns a BCP 47 language
+ * code.
+ */
+ @Nullable
+ String detectLanguageTag() {
+ synchronized (lock) {
+ if ((System.currentTimeMillis() - lastModifiedTime) < EXPIRATION_TIME) {
+ return cachedLanguageTag;
+ }
+ cachedLanguageTag = detectLanguageCodeInternal();
+ lastModifiedTime = System.currentTimeMillis();
+ return cachedLanguageTag;
}
+ }
- @VisibleForTesting
- LocationSignalProvider(
- LocationManager locationManager, TelephonyManager telephonyManager, Geocoder geocoder) {
- mLocationManager = Preconditions.checkNotNull(locationManager);
- mTelephonyManager = Preconditions.checkNotNull(telephonyManager);
- mGeocoder = Preconditions.checkNotNull(geocoder);
+ @Nullable
+ private String detectLanguageCodeInternal() {
+ String currentCountryCode = detectCurrentCountryCode();
+ if (currentCountryCode == null) {
+ return null;
}
+ return toLanguageTag(currentCountryCode);
+ }
- /**
- * Deduces the language by using user's current location as a signal and returns a BCP 47
- * language code.
- */
- @Nullable
- String detectLanguageTag() {
- synchronized (mLock) {
- if ((System.currentTimeMillis() - mLastModifiedTime) < EXPIRATION_TIME) {
- return mCachedLanguageTag;
- }
- mCachedLanguageTag = detectLanguageCodeInternal();
- mLastModifiedTime = System.currentTimeMillis();
- return mCachedLanguageTag;
- }
+ @Nullable
+ private String detectCurrentCountryCode() {
+ String networkCountryCode = telephonyManager.getNetworkCountryIso();
+ if (networkCountryCode != null) {
+ return networkCountryCode;
}
+ return detectCurrentCountryFromLocationManager();
+ }
- @Nullable
- private String detectLanguageCodeInternal() {
- String currentCountryCode = detectCurrentCountryCode();
- if (currentCountryCode == null) {
- return null;
- }
- return toLanguageTag(currentCountryCode);
+ @Nullable
+ private String detectCurrentCountryFromLocationManager() {
+ Location location = locationManager.getLastKnownLocation(LocationManager.PASSIVE_PROVIDER);
+ if (location == null) {
+ return null;
}
+ try {
+ List<Address> addresses =
+ geocoder.getFromLocation(
+ location.getLatitude(), location.getLongitude(), /*maxResults=*/ 1);
+ if (addresses != null && !addresses.isEmpty()) {
+ return addresses.get(0).getCountryCode();
+ }
+ } catch (IOException e) {
+ TcLog.e(TAG, "Failed to call getFromLocation: ", e);
+ return null;
+ }
+ return null;
+ }
- @Nullable
- private String detectCurrentCountryCode() {
- String networkCountryCode = mTelephonyManager.getNetworkCountryIso();
- if (networkCountryCode != null) {
- return networkCountryCode;
- }
- return detectCurrentCountryFromLocationManager();
- }
-
- @Nullable
- private String detectCurrentCountryFromLocationManager() {
- Location location = mLocationManager.getLastKnownLocation(LocationManager.PASSIVE_PROVIDER);
- if (location == null) {
- return null;
- }
- try {
- List<Address> addresses =
- mGeocoder.getFromLocation(
- location.getLatitude(), location.getLongitude(), /*maxResults=*/ 1);
- if (addresses != null && !addresses.isEmpty()) {
- return addresses.get(0).getCountryCode();
- }
- } catch (IOException e) {
- TcLog.e(TAG, "Failed to call getFromLocation: ", e);
- return null;
- }
- return null;
- }
-
- @Nullable
- private static String toLanguageTag(String countryTag) {
- ULocale locale = new ULocale.Builder().setRegion(countryTag).build();
- locale = ULocale.addLikelySubtags(locale);
- return locale.getLanguage();
- }
+ @Nullable
+ private static String toLanguageTag(String countryTag) {
+ ULocale locale = new ULocale.Builder().setRegion(countryTag).build();
+ locale = ULocale.addLikelySubtags(locale);
+ return locale.getLanguage();
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzer.java b/java/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzer.java
index 851e81b..d23be60 100644
--- a/java/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzer.java
+++ b/java/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzer.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,12 +19,9 @@
import android.content.Context;
import android.content.SharedPreferences;
import android.view.textclassifier.TextClassifierEvent;
-
import com.android.textclassifier.TcLog;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-
import org.json.JSONException;
import org.json.JSONObject;
@@ -37,159 +34,159 @@
* translation action in the future.
*/
class ReinforcementLanguageProficiencyAnalyzer implements LanguageProficiencyAnalyzer {
- private static final String TAG = "ReinforcementAnalyzer";
- private static final String PREF_NAME = "ulp-reinforcement-analyzer";
- private static final float SHOW_TRANSLATE_ACTION_THRESHOLD = 0.9f;
- private static final int MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT = 30;
+ private static final String TAG = "ReinforcementAnalyzer";
+ private static final String PREF_NAME = "ulp-reinforcement-analyzer";
+ private static final float SHOW_TRANSLATE_ACTION_THRESHOLD = 0.9f;
+ private static final int MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT = 30;
- private final SystemLanguagesProvider mSystemLanguagesProvider;
- private final SharedPreferences mSharedPreferences;
+ private final SystemLanguagesProvider systemLanguagesProvider;
+ private final SharedPreferences sharedPreferences;
- ReinforcementLanguageProficiencyAnalyzer(
- Context context, SystemLanguagesProvider systemLanguagesProvider) {
- Preconditions.checkNotNull(context);
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
- mSharedPreferences = context.getSharedPreferences(PREF_NAME, Context.MODE_PRIVATE);
+ ReinforcementLanguageProficiencyAnalyzer(
+ Context context, SystemLanguagesProvider systemLanguagesProvider) {
+ Preconditions.checkNotNull(context);
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ sharedPreferences = context.getSharedPreferences(PREF_NAME, Context.MODE_PRIVATE);
+ }
+
+ @VisibleForTesting
+ ReinforcementLanguageProficiencyAnalyzer(
+ SystemLanguagesProvider systemLanguagesProvider, SharedPreferences sharedPreferences) {
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ this.sharedPreferences = Preconditions.checkNotNull(sharedPreferences);
+ }
+
+ @Override
+ public float canUnderstand(String languageTag) {
+ TranslationStatistics translationStatistics =
+ TranslationStatistics.loadFromSharedPreference(sharedPreferences, languageTag);
+ if (translationStatistics.getShownCount() < MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT) {
+ return systemLanguagesProvider.getSystemLanguageTags().contains(languageTag) ? 1f : 0f;
+ }
+ return translationStatistics.getScore();
+ }
+
+ @Override
+ public boolean shouldShowTranslation(String languageTag) {
+ TranslationStatistics translationStatistics =
+ TranslationStatistics.loadFromSharedPreference(sharedPreferences, languageTag);
+ if (translationStatistics.getShownCount() < MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT) {
+ // Show translate action until we have enough feedback.
+ return true;
+ }
+ return translationStatistics.getScore() <= SHOW_TRANSLATE_ACTION_THRESHOLD;
+ }
+
+ @Override
+ public void onTextClassifierEvent(TextClassifierEvent event) {
+ if (event.getEventCategory() == TextClassifierEvent.CATEGORY_LANGUAGE_DETECTION) {
+ if (event.getEventType() == TextClassifierEvent.TYPE_SMART_ACTION
+ || event.getEventType() == TextClassifierEvent.TYPE_ACTIONS_SHOWN) {
+ onTranslateEvent(event);
+ }
+ }
+ }
+
+ private void onTranslateEvent(TextClassifierEvent event) {
+ if (event.getEntityTypes().length == 0) {
+ return;
+ }
+ String languageTag = event.getEntityTypes()[0];
+ // We only count the case that we show translate action in the prime position.
+ if (event.getActionIndices().length == 0 || event.getActionIndices()[0] != 0) {
+ return;
+ }
+ TranslationStatistics translationStatistics =
+ TranslationStatistics.loadFromSharedPreference(sharedPreferences, languageTag);
+ if (event.getEventType() == TextClassifierEvent.TYPE_ACTIONS_SHOWN) {
+ translationStatistics.increaseShownCountByOne();
+ } else if (event.getEventType() == TextClassifierEvent.TYPE_SMART_ACTION) {
+ translationStatistics.increaseClickedCountByOne();
+ }
+ translationStatistics.save(sharedPreferences, languageTag);
+ }
+
+ private static final class TranslationStatistics {
+ private static final String SEEN_COUNT = "seen_count";
+ private static final String CLICK_COUNT = "click_count";
+
+ private int shownCount;
+ private int clickCount;
+
+ private TranslationStatistics() {
+ this(/* seenCount= */ 0, /* clickCount= */ 0);
}
- @VisibleForTesting
- ReinforcementLanguageProficiencyAnalyzer(
- SystemLanguagesProvider systemLanguagesProvider, SharedPreferences sharedPreferences) {
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
- mSharedPreferences = Preconditions.checkNotNull(sharedPreferences);
+ private TranslationStatistics(int seenCount, int clickCount) {
+ shownCount = seenCount;
+ this.clickCount = clickCount;
+ }
+
+ static TranslationStatistics loadFromSharedPreference(
+ SharedPreferences sharedPreferences, String languageTag) {
+ String serializedString = sharedPreferences.getString(languageTag, null);
+ return TranslationStatistics.fromSerializedString(serializedString);
+ }
+
+ void save(SharedPreferences sharedPreferences, String languageTag) {
+ // TODO: Consider to store it in a database.
+ sharedPreferences.edit().putString(languageTag, serializeToString()).apply();
+ }
+
+ private String serializeToString() {
+ try {
+ JSONObject jsonObject = new JSONObject();
+ jsonObject.put(SEEN_COUNT, shownCount);
+ jsonObject.put(CLICK_COUNT, clickCount);
+ return jsonObject.toString();
+ } catch (JSONException ex) {
+ TcLog.e(TAG, "serializeToString: ", ex);
+ }
+ return "";
+ }
+
+ void increaseShownCountByOne() {
+ shownCount += 1;
+ }
+
+ void increaseClickedCountByOne() {
+ clickCount += 1;
+ }
+
+ float getScore() {
+ if (shownCount == 0) {
+ return 0f;
+ }
+ return clickCount / (float) shownCount;
+ }
+
+ int getShownCount() {
+ return shownCount;
+ }
+
+ static TranslationStatistics fromSerializedString(String str) {
+ if (str == null) {
+ return new TranslationStatistics();
+ }
+ try {
+ JSONObject jsonObject = new JSONObject(str);
+ int seenCount = jsonObject.getInt(SEEN_COUNT);
+ int clickCount = jsonObject.getInt(CLICK_COUNT);
+ return new TranslationStatistics(seenCount, clickCount);
+ } catch (JSONException ex) {
+ TcLog.e(TAG, "Failed to parse " + str, ex);
+ }
+ return new TranslationStatistics();
}
@Override
- public float canUnderstand(String languageTag) {
- TranslationStatistics translationStatistics =
- TranslationStatistics.loadFromSharedPreference(mSharedPreferences, languageTag);
- if (translationStatistics.getShownCount() < MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT) {
- return mSystemLanguagesProvider.getSystemLanguageTags().contains(languageTag) ? 1f : 0f;
- }
- return translationStatistics.getScore();
+ public String toString() {
+ return "TranslationStatistics{"
+ + "mShownCount="
+ + shownCount
+ + ", mClickCount="
+ + clickCount
+ + '}';
}
-
- @Override
- public boolean shouldShowTranslation(String languageTag) {
- TranslationStatistics translationStatistics =
- TranslationStatistics.loadFromSharedPreference(mSharedPreferences, languageTag);
- if (translationStatistics.getShownCount() < MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT) {
- // Show translate action until we have enough feedback.
- return true;
- }
- return translationStatistics.getScore() <= SHOW_TRANSLATE_ACTION_THRESHOLD;
- }
-
- @Override
- public void onTextClassifierEvent(TextClassifierEvent event) {
- if (event.getEventCategory() == TextClassifierEvent.CATEGORY_LANGUAGE_DETECTION) {
- if (event.getEventType() == TextClassifierEvent.TYPE_SMART_ACTION
- || event.getEventType() == TextClassifierEvent.TYPE_ACTIONS_SHOWN) {
- onTranslateEvent(event);
- }
- }
- }
-
- private void onTranslateEvent(TextClassifierEvent event) {
- if (event.getEntityTypes().length == 0) {
- return;
- }
- String languageTag = event.getEntityTypes()[0];
- // We only count the case that we show translate action in the prime position.
- if (event.getActionIndices().length == 0 || event.getActionIndices()[0] != 0) {
- return;
- }
- TranslationStatistics translationStatistics =
- TranslationStatistics.loadFromSharedPreference(mSharedPreferences, languageTag);
- if (event.getEventType() == TextClassifierEvent.TYPE_ACTIONS_SHOWN) {
- translationStatistics.increaseShownCountByOne();
- } else if (event.getEventType() == TextClassifierEvent.TYPE_SMART_ACTION) {
- translationStatistics.increaseClickedCountByOne();
- }
- translationStatistics.save(mSharedPreferences, languageTag);
- }
-
- private static final class TranslationStatistics {
- private static final String SEEN_COUNT = "seen_count";
- private static final String CLICK_COUNT = "click_count";
-
- private int mShownCount;
- private int mClickCount;
-
- private TranslationStatistics() {
- this(/* seenCount= */ 0, /* clickCount= */ 0);
- }
-
- private TranslationStatistics(int seenCount, int clickCount) {
- mShownCount = seenCount;
- mClickCount = clickCount;
- }
-
- static TranslationStatistics loadFromSharedPreference(
- SharedPreferences sharedPreferences, String languageTag) {
- String serializedString = sharedPreferences.getString(languageTag, null);
- return TranslationStatistics.fromSerializedString(serializedString);
- }
-
- void save(SharedPreferences sharedPreferences, String languageTag) {
- // TODO: Consider to store it in a database.
- sharedPreferences.edit().putString(languageTag, serializeToString()).apply();
- }
-
- private String serializeToString() {
- try {
- JSONObject jsonObject = new JSONObject();
- jsonObject.put(SEEN_COUNT, mShownCount);
- jsonObject.put(CLICK_COUNT, mClickCount);
- return jsonObject.toString();
- } catch (JSONException ex) {
- TcLog.e(TAG, "serializeToString: ", ex);
- }
- return "";
- }
-
- void increaseShownCountByOne() {
- mShownCount += 1;
- }
-
- void increaseClickedCountByOne() {
- mClickCount += 1;
- }
-
- float getScore() {
- if (mShownCount == 0) {
- return 0f;
- }
- return mClickCount / (float) mShownCount;
- }
-
- int getShownCount() {
- return mShownCount;
- }
-
- static TranslationStatistics fromSerializedString(String str) {
- if (str == null) {
- return new TranslationStatistics();
- }
- try {
- JSONObject jsonObject = new JSONObject(str);
- int seenCount = jsonObject.getInt(SEEN_COUNT);
- int clickCount = jsonObject.getInt(CLICK_COUNT);
- return new TranslationStatistics(seenCount, clickCount);
- } catch (JSONException ex) {
- TcLog.e(TAG, "Failed to parse " + str, ex);
- }
- return new TranslationStatistics();
- }
-
- @Override
- public String toString() {
- return "TranslationStatistics{"
- + "mShownCount="
- + mShownCount
- + ", mClickCount="
- + mClickCount
- + '}';
- }
- }
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/SystemLanguagesProvider.java b/java/src/com/android/textclassifier/ulp/SystemLanguagesProvider.java
index e167566..eefa913 100644
--- a/java/src/com/android/textclassifier/ulp/SystemLanguagesProvider.java
+++ b/java/src/com/android/textclassifier/ulp/SystemLanguagesProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,18 +18,17 @@
import android.content.res.Resources;
import android.os.LocaleList;
-
import java.util.ArrayList;
import java.util.List;
final class SystemLanguagesProvider {
- List<String> getSystemLanguageTags() {
- LocaleList localeList = Resources.getSystem().getConfiguration().getLocales();
- List<String> languageTags = new ArrayList<>();
- for (int i = 0; i < localeList.size(); i++) {
- languageTags.add(localeList.get(i).getLanguage());
- }
- return languageTags;
+ List<String> getSystemLanguageTags() {
+ LocaleList localeList = Resources.getSystem().getConfiguration().getLocales();
+ List<String> languageTags = new ArrayList<>();
+ for (int i = 0; i < localeList.size(); i++) {
+ languageTags.add(localeList.get(i).getLanguage());
}
+ return languageTags;
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/database/LanguageProfileDatabase.java b/java/src/com/android/textclassifier/ulp/database/LanguageProfileDatabase.java
index 2ae317e..2e213d5 100644
--- a/java/src/com/android/textclassifier/ulp/database/LanguageProfileDatabase.java
+++ b/java/src/com/android/textclassifier/ulp/database/LanguageProfileDatabase.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
package com.android.textclassifier.ulp.database;
import android.content.Context;
-
import androidx.annotation.GuardedBy;
import androidx.room.Database;
import androidx.room.Room;
@@ -30,36 +29,36 @@
* existing if there is one) and use it.
*/
@Database(
- entities = {LanguageSignalInfo.class},
- version = 1,
- exportSchema = false)
+ entities = {LanguageSignalInfo.class},
+ version = 1,
+ exportSchema = false)
public abstract class LanguageProfileDatabase extends RoomDatabase {
- private static final Object sLock = new Object();
+ private static final Object lock = new Object();
- @GuardedBy("sLock")
- private static LanguageProfileDatabase sINSTANCE;
+ @GuardedBy("lock")
+ private static LanguageProfileDatabase instance;
- /**
- * Returns {@link LanguageSignalInfoDao} object belonging to the {@link LanguageProfileDatabase}
- * with which we can call database queries.
- */
- public abstract LanguageSignalInfoDao languageInfoDao();
+ /**
+ * Returns {@link LanguageSignalInfoDao} object belonging to the {@link LanguageProfileDatabase}
+ * with which we can call database queries.
+ */
+ public abstract LanguageSignalInfoDao languageInfoDao();
- /**
- * Create an instance of {@link LanguageProfileDatabase} for chosen context or existing one if
- * it was already created.
- */
- public static LanguageProfileDatabase getInstance(final Context context) {
- synchronized (sLock) {
- if (sINSTANCE == null) {
- sINSTANCE =
- Room.databaseBuilder(
- context.getApplicationContext(),
- LanguageProfileDatabase.class,
- "language_profile")
- .build();
- }
- return sINSTANCE;
- }
+ /**
+ * Create an instance of {@link LanguageProfileDatabase} for chosen context or existing one if it
+ * was already created.
+ */
+ public static LanguageProfileDatabase getInstance(final Context context) {
+ synchronized (lock) {
+ if (instance == null) {
+ instance =
+ Room.databaseBuilder(
+ context.getApplicationContext(),
+ LanguageProfileDatabase.class,
+ "language_profile")
+ .build();
+ }
+ return instance;
}
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfo.java b/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfo.java
index b494e93..98f63ac 100644
--- a/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfo.java
+++ b/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfo.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,10 +18,9 @@
import androidx.annotation.IntDef;
import androidx.annotation.NonNull;
-import androidx.core.util.Preconditions;
import androidx.room.ColumnInfo;
import androidx.room.Entity;
-
+import com.google.common.base.Preconditions;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
@@ -32,77 +31,78 @@
* specified language.
*/
@Entity(
- tableName = "language_signal_infos",
- primaryKeys = {"languageTag", "source"})
+ tableName = "language_signal_infos",
+ primaryKeys = {"languageTag", "source"})
public final class LanguageSignalInfo {
- @Retention(RetentionPolicy.SOURCE)
- @IntDef({SUGGEST_CONVERSATION_ACTIONS, CLASSIFY_TEXT})
- public @interface Source {}
+ /** The source of the signal */
+ @Retention(RetentionPolicy.SOURCE)
+ @IntDef({SUGGEST_CONVERSATION_ACTIONS, CLASSIFY_TEXT})
+ public @interface Source {}
- public static final int SUGGEST_CONVERSATION_ACTIONS = 0;
- public static final int CLASSIFY_TEXT = 1;
+ public static final int SUGGEST_CONVERSATION_ACTIONS = 0;
+ public static final int CLASSIFY_TEXT = 1;
- @NonNull
- @ColumnInfo(name = "languageTag")
- private String mLanguageTag;
+ @NonNull
+ @ColumnInfo(name = "languageTag")
+ private final String languageTag;
- @ColumnInfo(name = "source")
- private int mSource;
+ @ColumnInfo(name = "source")
+ private final int source;
- @ColumnInfo(name = "count")
- private int mCount;
+ @ColumnInfo(name = "count")
+ private final int count;
- public LanguageSignalInfo(String languageTag, @Source int source, int count) {
- mLanguageTag = Preconditions.checkNotNull(languageTag);
- mSource = source;
- mCount = count;
+ public LanguageSignalInfo(String languageTag, @Source int source, int count) {
+ this.languageTag = Preconditions.checkNotNull(languageTag);
+ this.source = source;
+ this.count = count;
+ }
+
+ public String getLanguageTag() {
+ return languageTag;
+ }
+
+ @Source
+ public int getSource() {
+ return source;
+ }
+
+ public int getCount() {
+ return count;
+ }
+
+ @Override
+ public String toString() {
+ String src = "OTHER";
+ if (source == SUGGEST_CONVERSATION_ACTIONS) {
+ src = "SUGGEST_CONVERSATION_ACTIONS";
+ } else if (source == CLASSIFY_TEXT) {
+ src = "CLASSIFY_TEXT";
}
- public String getLanguageTag() {
- return mLanguageTag;
- }
+ return languageTag + "_" + src + ": " + count;
+ }
- @Source
- public int getSource() {
- return mSource;
- }
+ @Override
+ public int hashCode() {
+ int result = languageTag.hashCode();
+ result = 31 * result + Integer.hashCode(source);
+ result = 31 * result + Integer.hashCode(count);
+ return result;
+ }
- public int getCount() {
- return mCount;
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
}
-
- @Override
- public String toString() {
- String src = "OTHER";
- if (mSource == SUGGEST_CONVERSATION_ACTIONS) {
- src = "SUGGEST_CONVERSATION_ACTIONS";
- } else if (mSource == CLASSIFY_TEXT) {
- src = "CLASSIFY_TEXT";
- }
-
- return mLanguageTag + "_" + src + ": " + mCount;
+ if (obj == null || obj.getClass() != getClass()) {
+ return false;
}
-
- @Override
- public int hashCode() {
- int result = mLanguageTag.hashCode();
- result = 31 * result + Integer.hashCode(mSource);
- result = 31 * result + Integer.hashCode(mCount);
- return result;
- }
-
- @Override
- public boolean equals(Object obj) {
- if (this == obj) {
- return true;
- }
- if (obj == null || obj.getClass() != getClass()) {
- return false;
- }
- LanguageSignalInfo info = (LanguageSignalInfo) obj;
- return mLanguageTag.equals(info.getLanguageTag())
- && mSource == info.getSource()
- && mCount == info.getCount();
- }
+ LanguageSignalInfo info = (LanguageSignalInfo) obj;
+ return languageTag.equals(info.getLanguageTag())
+ && source == info.getSource()
+ && count == info.getCount();
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfoDao.java b/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfoDao.java
index adc1a59..ca792b2 100644
--- a/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfoDao.java
+++ b/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfoDao.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,9 +20,6 @@
import androidx.room.Insert;
import androidx.room.OnConflictStrategy;
import androidx.room.Query;
-
-import com.google.common.annotations.VisibleForTesting;
-
import java.util.List;
/**
@@ -34,31 +31,30 @@
@Dao
public interface LanguageSignalInfoDao {
- /**
- * Inserts {@link LanguageSignalInfo} object into the Room database. If there was already entity
- * with the same language local and source type replaces it with a passed one.
- *
- * @param languageSignalInfo object to insert into the database.
- */
- @Insert(onConflict = OnConflictStrategy.REPLACE)
- @VisibleForTesting
- void insertLanguageInfo(LanguageSignalInfo languageSignalInfo);
+ /**
+ * Inserts {@link LanguageSignalInfo} object into the Room database. If there was already entity
+ * with the same language local and source type replaces it with a passed one.
+ *
+ * @param languageSignalInfo object to insert into the database.
+ */
+ @Insert(onConflict = OnConflictStrategy.REPLACE)
+ void insertLanguageInfo(LanguageSignalInfo languageSignalInfo);
- /** Returns all the {@link LanguageSignalInfo} objects which have source like {@code src}. */
- @Query("SELECT * FROM language_signal_infos WHERE source = :src")
- List<LanguageSignalInfo> getBySource(@LanguageSignalInfo.Source int src);
+ /** Returns all the {@link LanguageSignalInfo} objects which have source like {@code src}. */
+ @Query("SELECT * FROM language_signal_infos WHERE source = :src")
+ List<LanguageSignalInfo> getBySource(@LanguageSignalInfo.Source int src);
- /** Returns all the {@link LanguageSignalInfo} objects stored in the database. */
- @Query("SELECT * FROM language_signal_infos")
- List<LanguageSignalInfo> getAll();
+ /** Returns all the {@link LanguageSignalInfo} objects stored in the database. */
+ @Query("SELECT * FROM language_signal_infos")
+ List<LanguageSignalInfo> getAll();
- /**
- * Increases the count of the specified signal by the specified increment or inserts a new entry
- * if the signal is not in the database yet.
- */
- @Query(
- "INSERT INTO language_signal_infos VALUES(:languageTag, :source, :increment)"
- + " ON CONFLICT(languageTag, source) DO UPDATE SET count = count + :increment")
- void increaseSignalCount(
- String languageTag, @LanguageSignalInfo.Source int source, int increment);
+ /**
+ * Increases the count of the specified signal by the specified increment or inserts a new entry
+ * if the signal is not in the database yet.
+ */
+ @Query(
+ "INSERT INTO language_signal_infos VALUES(:languageTag, :source, :increment)"
+ + " ON CONFLICT(languageTag, source) DO UPDATE SET count = count + :increment")
+ void increaseSignalCount(
+ String languageTag, @LanguageSignalInfo.Source int source, int increment);
}
diff --git a/java/src/com/android/textclassifier/ulp/kmeans/KMeans.java b/java/src/com/android/textclassifier/ulp/kmeans/KMeans.java
index e77b05e..a01af54 100644
--- a/java/src/com/android/textclassifier/ulp/kmeans/KMeans.java
+++ b/java/src/com/android/textclassifier/ulp/kmeans/KMeans.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,11 +17,7 @@
package com.android.textclassifier.ulp.kmeans;
import android.util.Log;
-
-import androidx.annotation.NonNull;
-
import com.google.common.annotations.VisibleForTesting;
-
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -29,226 +25,218 @@
/** Simple K-Means implementation which found in android internal ml library */
public class KMeans {
- private static final boolean DEBUG = false;
- private static final String TAG = "KMeans";
- private final Random mRandomState;
- private final int mMaxIterations;
- private float mSqConvergenceEpsilon;
+ private static final boolean DEBUG = false;
+ private static final String TAG = "KMeans";
+ private final Random randomState;
+ private final int maxIterations;
+ private final float sqConvergenceEpsilon;
- public KMeans() {
- this(new Random());
+ public KMeans() {
+ this(new Random());
+ }
+
+ public KMeans(Random random) {
+ this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
+ }
+
+ public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
+ randomState = random;
+ this.maxIterations = maxIterations;
+ sqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
+ }
+
+ /**
+ * Runs k-means on the input data (X) trying to find k means.
+ *
+ * <p>K-Means is known for getting stuck into local optima, so you might want to run it multiple
+ * time and argmax on {@link KMeans#score(List)}
+ *
+ * @param k The number of points to return.
+ * @param inputData Input data.
+ * @return An array of k Means, each representing a centroid and data points that belong to it.
+ */
+ public List<Mean> predict(final int k, final float[][] inputData) {
+ checkDataSetSanity(inputData);
+ int dimension = inputData[0].length;
+
+ final ArrayList<Mean> means = new ArrayList<>();
+ for (int i = 0; i < k; i++) {
+ Mean m = new Mean(dimension);
+ for (int j = 0; j < dimension; j++) {
+ m.centroid[j] = randomState.nextFloat();
+ }
+ means.add(m);
}
- public KMeans(Random random) {
- this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
+ // Iterate until we converge or run out of iterations
+ boolean converged = false;
+ for (int i = 0; i < maxIterations; i++) {
+ converged = step(means, inputData);
+ if (converged) {
+ if (DEBUG) {
+ Log.d(TAG, "Converged at iteration: " + i);
+ }
+ break;
+ }
+ }
+ if (!converged && DEBUG) {
+ Log.d(TAG, "Did not converge");
}
- public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
- mRandomState = random;
- mMaxIterations = maxIterations;
- mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
+ return means;
+ }
+
+ /**
+ * Score calculates the inertia between means. This can be considered as an E step of an EM
+ * algorithm.
+ *
+ * @param means Means to use when calculating score.
+ * @return The score
+ */
+ public static double score(List<Mean> means) {
+ double score = 0;
+ final int meansSize = means.size();
+ for (int i = 0; i < meansSize; i++) {
+ Mean mean = means.get(i);
+ for (int j = 0; j < meansSize; j++) {
+ Mean compareTo = means.get(j);
+ if (mean == compareTo) {
+ continue;
+ }
+ double distance = Math.sqrt(sqDistance(mean.centroid, compareTo.centroid));
+ score += distance;
+ }
+ }
+ return score;
+ }
+
+ /** */
+ @VisibleForTesting
+ public void checkDataSetSanity(float[][] inputData) {
+ if (inputData == null) {
+ throw new IllegalArgumentException("Data set is null.");
+ } else if (inputData.length == 0) {
+ throw new IllegalArgumentException("Data set is empty.");
+ } else if (inputData[0] == null) {
+ throw new IllegalArgumentException("Bad data set format.");
}
- /**
- * Runs k-means on the input data (X) trying to find k means.
- *
- * <p>K-Means is known for getting stuck into local optima, so you might want to run it multiple
- * time and argmax on {@link KMeans#score(List)}
- *
- * @param k The number of points to return.
- * @param inputData Input data.
- * @return An array of k Means, each representing a centroid and data points that belong to it.
- */
- public List<Mean> predict(final int k, final float[][] inputData) {
- checkDataSetSanity(inputData);
- int dimension = inputData[0].length;
+ final int dimension = inputData[0].length;
+ final int length = inputData.length;
+ for (int i = 1; i < length; i++) {
+ if (inputData[i] == null || inputData[i].length != dimension) {
+ throw new IllegalArgumentException("Bad data set format.");
+ }
+ }
+ }
- final ArrayList<Mean> means = new ArrayList<>();
- for (int i = 0; i < k; i++) {
- Mean m = new Mean(dimension);
- for (int j = 0; j < dimension; j++) {
- m.mCentroid[j] = mRandomState.nextFloat();
- }
- means.add(m);
- }
+ /**
+ * K-Means iteration.
+ *
+ * @param means Current means
+ * @param inputData Input data
+ * @return True if data set converged
+ */
+ private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
- // Iterate until we converge or run out of iterations
- boolean converged = false;
- for (int i = 0; i < mMaxIterations; i++) {
- converged = step(means, inputData);
- if (converged) {
- if (DEBUG) Log.d(TAG, "Converged at iteration: " + i);
- break;
- }
- }
- if (!converged && DEBUG) Log.d(TAG, "Did not converge");
-
- return means;
+ // Clean up the previous state because we need to compute
+ // which point belongs to each mean again.
+ for (int i = means.size() - 1; i >= 0; i--) {
+ final Mean mean = means.get(i);
+ mean.closestItems.clear();
+ }
+ for (int i = inputData.length - 1; i >= 0; i--) {
+ final float[] current = inputData[i];
+ final Mean nearest = nearestMean(current, means);
+ nearest.closestItems.add(current);
}
- /**
- * Score calculates the inertia between means. This can be considered as an E step of an EM
- * algorithm.
- *
- * @param means Means to use when calculating score.
- * @return The score
- */
- public static double score(@NonNull List<Mean> means) {
- double score = 0;
- final int meansSize = means.size();
- for (int i = 0; i < meansSize; i++) {
- Mean mean = means.get(i);
- for (int j = 0; j < meansSize; j++) {
- Mean compareTo = means.get(j);
- if (mean == compareTo) {
- continue;
- }
- double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid));
- score += distance;
- }
+ boolean converged = true;
+ // Move each mean towards the nearest data set points
+ for (int i = means.size() - 1; i >= 0; i--) {
+ final Mean mean = means.get(i);
+ if (mean.closestItems.isEmpty()) {
+ continue;
+ }
+
+ // Compute the new mean centroid:
+ // 1. Sum all all points
+ // 2. Average them
+ final float[] oldCentroid = mean.centroid;
+ mean.centroid = new float[oldCentroid.length];
+ for (int j = 0; j < mean.closestItems.size(); j++) {
+ // Update each centroid component
+ for (int p = 0; p < mean.centroid.length; p++) {
+ mean.centroid[p] += mean.closestItems.get(j)[p];
}
- return score;
+ }
+ for (int j = 0; j < mean.centroid.length; j++) {
+ mean.centroid[j] /= mean.closestItems.size();
+ }
+
+ // We converged if the centroid didn't move for any of the means.
+ if (sqDistance(oldCentroid, mean.centroid) > sqConvergenceEpsilon) {
+ converged = false;
+ }
+ }
+ return converged;
+ }
+
+ /** */
+ @VisibleForTesting
+ public static Mean nearestMean(float[] point, List<Mean> means) {
+ Mean nearest = null;
+ float nearestDistance = Float.MAX_VALUE;
+
+ final int meanCount = means.size();
+ for (int i = 0; i < meanCount; i++) {
+ Mean next = means.get(i);
+ // We don't need the sqrt when comparing distances in euclidean space
+ // because they exist on both sides of the equation and cancel each other out.
+ float nextDistance = sqDistance(point, next.centroid);
+ if (nextDistance < nearestDistance) {
+ nearest = next;
+ nearestDistance = nextDistance;
+ }
+ }
+ return nearest;
+ }
+
+ /** */
+ @VisibleForTesting
+ public static float sqDistance(float[] a, float[] b) {
+ float dist = 0;
+ final int length = a.length;
+ for (int i = 0; i < length; i++) {
+ dist += (a[i] - b[i]) * (a[i] - b[i]);
+ }
+ return dist;
+ }
+
+ /** Definition of a mean, contains a centroid and points on its cluster. */
+ public static class Mean {
+ float[] centroid;
+ final ArrayList<float[]> closestItems = new ArrayList<>();
+
+ public Mean(int dimension) {
+ centroid = new float[dimension];
}
- /** @param inputData */
- @VisibleForTesting
- public void checkDataSetSanity(float[][] inputData) {
- if (inputData == null) {
- throw new IllegalArgumentException("Data set is null.");
- } else if (inputData.length == 0) {
- throw new IllegalArgumentException("Data set is empty.");
- } else if (inputData[0] == null) {
- throw new IllegalArgumentException("Bad data set format.");
- }
-
- final int dimension = inputData[0].length;
- final int length = inputData.length;
- for (int i = 1; i < length; i++) {
- if (inputData[i] == null || inputData[i].length != dimension) {
- throw new IllegalArgumentException("Bad data set format.");
- }
- }
+ public Mean(float... centroid) {
+ this.centroid = centroid;
}
- /**
- * K-Means iteration.
- *
- * @param means Current means
- * @param inputData Input data
- * @return True if data set converged
- */
- private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
-
- // Clean up the previous state because we need to compute
- // which point belongs to each mean again.
- for (int i = means.size() - 1; i >= 0; i--) {
- final Mean mean = means.get(i);
- mean.mClosestItems.clear();
- }
- for (int i = inputData.length - 1; i >= 0; i--) {
- final float[] current = inputData[i];
- final Mean nearest = nearestMean(current, means);
- nearest.mClosestItems.add(current);
- }
-
- boolean converged = true;
- // Move each mean towards the nearest data set points
- for (int i = means.size() - 1; i >= 0; i--) {
- final Mean mean = means.get(i);
- if (mean.mClosestItems.size() == 0) {
- continue;
- }
-
- // Compute the new mean centroid:
- // 1. Sum all all points
- // 2. Average them
- final float[] oldCentroid = mean.mCentroid;
- mean.mCentroid = new float[oldCentroid.length];
- for (int j = 0; j < mean.mClosestItems.size(); j++) {
- // Update each centroid component
- for (int p = 0; p < mean.mCentroid.length; p++) {
- mean.mCentroid[p] += mean.mClosestItems.get(j)[p];
- }
- }
- for (int j = 0; j < mean.mCentroid.length; j++) {
- mean.mCentroid[j] /= mean.mClosestItems.size();
- }
-
- // We converged if the centroid didn't move for any of the means.
- if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) {
- converged = false;
- }
- }
- return converged;
+ public float[] getCentroid() {
+ return centroid;
}
- /**
- * @param point
- * @param means
- * @return
- */
- @VisibleForTesting
- public static Mean nearestMean(float[] point, List<Mean> means) {
- Mean nearest = null;
- float nearestDistance = Float.MAX_VALUE;
-
- final int meanCount = means.size();
- for (int i = 0; i < meanCount; i++) {
- Mean next = means.get(i);
- // We don't need the sqrt when comparing distances in euclidean space
- // because they exist on both sides of the equation and cancel each other out.
- float nextDistance = sqDistance(point, next.mCentroid);
- if (nextDistance < nearestDistance) {
- nearest = next;
- nearestDistance = nextDistance;
- }
- }
- return nearest;
+ public List<float[]> getItems() {
+ return closestItems;
}
- /**
- * @param a
- * @param b
- * @return
- */
- @VisibleForTesting
- public static float sqDistance(float[] a, float[] b) {
- float dist = 0;
- final int length = a.length;
- for (int i = 0; i < length; i++) {
- dist += (a[i] - b[i]) * (a[i] - b[i]);
- }
- return dist;
+ @Override
+ public String toString() {
+ return "Mean(centroid: " + Arrays.toString(centroid) + ", size: " + closestItems.size() + ")";
}
-
- /** Definition of a mean, contains a centroid and points on its cluster. */
- public static class Mean {
- float[] mCentroid;
- final ArrayList<float[]> mClosestItems = new ArrayList<>();
-
- public Mean(int dimension) {
- mCentroid = new float[dimension];
- }
-
- public Mean(float... centroid) {
- mCentroid = centroid;
- }
-
- public float[] getCentroid() {
- return mCentroid;
- }
-
- public List<float[]> getItems() {
- return mClosestItems;
- }
-
- @Override
- public String toString() {
- return "Mean(centroid: "
- + Arrays.toString(mCentroid)
- + ", size: "
- + mClosestItems.size()
- + ")";
- }
- }
+ }
}
diff --git a/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java b/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java
index addc8e8..ea96633 100644
--- a/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java
+++ b/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -11,13 +11,12 @@
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
- * limitations under the License
+ * limitations under the License.
*/
package com.android.textclassifier.utils;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
import java.io.PrintWriter;
/**
@@ -26,52 +25,52 @@
* @see PrintWriter
*/
public final class IndentingPrintWriter {
- private static final String SINGLE_INDENT = " ";
+ private static final String SINGLE_INDENT = " ";
- private final PrintWriter mWriter;
- private StringBuilder mIndentBuilder = new StringBuilder();
- private String mCurrentIndent = "";
+ private final PrintWriter writer;
+ private final StringBuilder indentBuilder = new StringBuilder();
+ private String currentIndent = "";
- public IndentingPrintWriter(PrintWriter writer) {
- mWriter = Preconditions.checkNotNull(writer);
- }
+ public IndentingPrintWriter(PrintWriter writer) {
+ this.writer = Preconditions.checkNotNull(writer);
+ }
- /** Prints a string. */
- public IndentingPrintWriter println(String string) {
- mWriter.print(mCurrentIndent);
- mWriter.print(string);
- mWriter.println();
- return this;
- }
+ /** Prints a string. */
+ public IndentingPrintWriter println(String string) {
+ writer.print(currentIndent);
+ writer.print(string);
+ writer.println();
+ return this;
+ }
- /** Prints a empty line */
- public IndentingPrintWriter println() {
- mWriter.println();
- return this;
- }
+ /** Prints a empty line */
+ public IndentingPrintWriter println() {
+ writer.println();
+ return this;
+ }
- /** Increases indents for subsequent texts. */
- public IndentingPrintWriter increaseIndent() {
- mIndentBuilder.append(SINGLE_INDENT);
- mCurrentIndent = mIndentBuilder.toString();
- return this;
- }
+ /** Increases indents for subsequent texts. */
+ public IndentingPrintWriter increaseIndent() {
+ indentBuilder.append(SINGLE_INDENT);
+ currentIndent = indentBuilder.toString();
+ return this;
+ }
- /** Decreases indents for subsequent texts. */
- public IndentingPrintWriter decreaseIndent() {
- mIndentBuilder.delete(0, SINGLE_INDENT.length());
- mCurrentIndent = mIndentBuilder.toString();
- return this;
- }
+ /** Decreases indents for subsequent texts. */
+ public IndentingPrintWriter decreaseIndent() {
+ indentBuilder.delete(0, SINGLE_INDENT.length());
+ currentIndent = indentBuilder.toString();
+ return this;
+ }
- /** Prints a key-valued pair. */
- public IndentingPrintWriter printPair(String key, Object value) {
- println(String.format("%s=%s", key, String.valueOf(value)));
- return this;
- }
+ /** Prints a key-valued pair. */
+ public IndentingPrintWriter printPair(String key, Object value) {
+ println(String.format("%s=%s", key, String.valueOf(value)));
+ return this;
+ }
- /** Flushes the stream. */
- public void flush() {
- mWriter.flush();
- }
+ /** Flushes the stream. */
+ public void flush() {
+ writer.flush();
+ }
}
diff --git a/java/tests/unittests/Android.bp b/java/tests/instrumentation/Android.bp
similarity index 89%
rename from java/tests/unittests/Android.bp
rename to java/tests/instrumentation/Android.bp
index ab3f713..c742c8d 100644
--- a/java/tests/unittests/Android.bp
+++ b/java/tests/instrumentation/Android.bp
@@ -22,12 +22,15 @@
srcs: [
"src/**/*.java",
],
+ // TODO: Re-enable the ulp tests.
+ exclude_srcs: ["src/**/ulp/*.java"],
static_libs: [
+ "androidx.test.ext.junit",
"androidx.test.rules",
"androidx.test.espresso.core",
"androidx.test.ext.truth",
- "mockito-target-inline",
+ "mockito-target-minus-junit4",
"ub-uiautomator",
"testng",
"compatibility-device-util-axt",
diff --git a/java/tests/instrumentation/AndroidManifest.xml b/java/tests/instrumentation/AndroidManifest.xml
new file mode 100644
index 0000000..5de247c
--- /dev/null
+++ b/java/tests/instrumentation/AndroidManifest.xml
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="utf-8"?>
+<manifest xmlns:android="http://schemas.android.com/apk/res/android"
+ package="com.android.textclassifier.tests">
+
+ <uses-sdk android:minSdkVersion="28"/>
+
+ <application>
+ <uses-library android:name="android.test.runner"/>
+ </application>
+
+ <instrumentation
+ android:name="androidx.test.runner.AndroidJUnitRunner"
+ android:targetPackage="com.android.textclassifier.tests"/>
+</manifest>
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ActionsModelParamsSupplierTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ActionsModelParamsSupplierTest.java
new file mode 100644
index 0000000..6a9d08a
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ActionsModelParamsSupplierTest.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import java.io.File;
+import java.util.Collections;
+import java.util.Locale;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class ActionsModelParamsSupplierTest {
+
+ @Test
+ public void getSerializedPreconditions_validActionsModelParams() {
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(
+ new File("/model/file"),
+ 200 /* version */,
+ Collections.singletonList(Locale.forLanguageTag("en")),
+ "en",
+ false);
+ byte[] serializedPreconditions = new byte[] {0x12, 0x24, 0x36};
+ ActionsModelParamsSupplier.ActionsModelParams params =
+ new ActionsModelParamsSupplier.ActionsModelParams(
+ 200 /* version */, "en", serializedPreconditions);
+
+ byte[] actual = params.getSerializedPreconditions(modelFile);
+
+ assertThat(actual).isEqualTo(serializedPreconditions);
+ }
+
+ @Test
+ public void getSerializedPreconditions_invalidVersion() {
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(
+ new File("/model/file"),
+ 201 /* version */,
+ Collections.singletonList(Locale.forLanguageTag("en")),
+ "en",
+ false);
+ byte[] serializedPreconditions = new byte[] {0x12, 0x24, 0x36};
+ ActionsModelParamsSupplier.ActionsModelParams params =
+ new ActionsModelParamsSupplier.ActionsModelParams(
+ 200 /* version */, "en", serializedPreconditions);
+
+ byte[] actual = params.getSerializedPreconditions(modelFile);
+
+ assertThat(actual).isNull();
+ }
+
+ @Test
+ public void getSerializedPreconditions_invalidLocales() {
+ final String languageTag = "zh";
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(
+ new File("/model/file"),
+ 200 /* version */,
+ Collections.singletonList(Locale.forLanguageTag(languageTag)),
+ languageTag,
+ false);
+ byte[] serializedPreconditions = new byte[] {0x12, 0x24, 0x36};
+ ActionsModelParamsSupplier.ActionsModelParams params =
+ new ActionsModelParamsSupplier.ActionsModelParams(
+ 200 /* version */, "en", serializedPreconditions);
+
+ byte[] actual = params.getSerializedPreconditions(modelFile);
+
+ assertThat(actual).isNull();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
new file mode 100644
index 0000000..b48d361
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
@@ -0,0 +1,305 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static android.view.textclassifier.ConversationActions.Message.PERSON_USER_OTHERS;
+import static android.view.textclassifier.ConversationActions.Message.PERSON_USER_SELF;
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.PendingIntent;
+import android.app.Person;
+import android.app.RemoteAction;
+import android.content.ComponentName;
+import android.content.Intent;
+import android.graphics.drawable.Icon;
+import android.net.Uri;
+import android.os.Bundle;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.intent.LabeledIntent;
+import com.android.textclassifier.intent.TemplateIntentFactory;
+import com.google.android.textclassifier.ActionsSuggestionsModel;
+import com.google.android.textclassifier.RemoteActionTemplate;
+import com.google.common.collect.ImmutableList;
+import java.time.Instant;
+import java.time.ZoneId;
+import java.time.ZonedDateTime;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.function.Function;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class ActionsSuggestionsHelperTest {
+ private static final String LOCALE_TAG = Locale.US.toLanguageTag();
+ private static final Function<CharSequence, List<String>> LANGUAGE_DETECTOR =
+ charSequence -> Collections.singletonList(LOCALE_TAG);
+
+ @Test
+ public void testToNativeMessages_emptyInput() {
+ ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
+ ActionsSuggestionsHelper.toNativeMessages(ImmutableList.of(), LANGUAGE_DETECTOR);
+
+ assertThat(conversationMessages).isEmpty();
+ }
+
+ @Test
+ public void testToNativeMessages_noTextMessages() {
+ ConversationActions.Message messageWithoutText =
+ new ConversationActions.Message.Builder(PERSON_USER_OTHERS).build();
+
+ ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
+ ActionsSuggestionsHelper.toNativeMessages(
+ ImmutableList.of(messageWithoutText), LANGUAGE_DETECTOR);
+
+ assertThat(conversationMessages).isEmpty();
+ }
+
+ @Test
+ public void testToNativeMessages_userIdEncoding() {
+ Person userA = new Person.Builder().setName("userA").build();
+ Person userB = new Person.Builder().setName("userB").build();
+
+ ConversationActions.Message firstMessage =
+ new ConversationActions.Message.Builder(userB).setText("first").build();
+ ConversationActions.Message secondMessage =
+ new ConversationActions.Message.Builder(userA).setText("second").build();
+ ConversationActions.Message thirdMessage =
+ new ConversationActions.Message.Builder(PERSON_USER_SELF).setText("third").build();
+ ConversationActions.Message fourthMessage =
+ new ConversationActions.Message.Builder(userA).setText("fourth").build();
+
+ ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
+ ActionsSuggestionsHelper.toNativeMessages(
+ Arrays.asList(firstMessage, secondMessage, thirdMessage, fourthMessage),
+ LANGUAGE_DETECTOR);
+
+ assertThat(conversationMessages).hasLength(4);
+ assertNativeMessage(conversationMessages[0], firstMessage.getText(), 2, 0);
+ assertNativeMessage(conversationMessages[1], secondMessage.getText(), 1, 0);
+ assertNativeMessage(conversationMessages[2], thirdMessage.getText(), 0, 0);
+ assertNativeMessage(conversationMessages[3], fourthMessage.getText(), 1, 0);
+ }
+
+ @Test
+ public void testToNativeMessages_referenceTime() {
+ ConversationActions.Message firstMessage =
+ new ConversationActions.Message.Builder(PERSON_USER_OTHERS)
+ .setText("first")
+ .setReferenceTime(createZonedDateTimeFromMsUtc(1000))
+ .build();
+ ConversationActions.Message secondMessage =
+ new ConversationActions.Message.Builder(PERSON_USER_OTHERS).setText("second").build();
+ ConversationActions.Message thirdMessage =
+ new ConversationActions.Message.Builder(PERSON_USER_OTHERS)
+ .setText("third")
+ .setReferenceTime(createZonedDateTimeFromMsUtc(2000))
+ .build();
+
+ ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
+ ActionsSuggestionsHelper.toNativeMessages(
+ Arrays.asList(firstMessage, secondMessage, thirdMessage), LANGUAGE_DETECTOR);
+
+ assertThat(conversationMessages).hasLength(3);
+ assertNativeMessage(conversationMessages[0], firstMessage.getText(), 1, 1000);
+ assertNativeMessage(conversationMessages[1], secondMessage.getText(), 1, 0);
+ assertNativeMessage(conversationMessages[2], thirdMessage.getText(), 1, 2000);
+ }
+
+ @Test
+ public void testDeduplicateActions() {
+ Bundle phoneExtras = new Bundle();
+ Intent phoneIntent = new Intent();
+ phoneIntent.setComponent(new ComponentName("phone", "intent"));
+ ExtrasUtils.putActionIntent(phoneExtras, phoneIntent);
+
+ Bundle anotherPhoneExtras = new Bundle();
+ Intent anotherPhoneIntent = new Intent();
+ anotherPhoneIntent.setComponent(new ComponentName("phone", "another.intent"));
+ ExtrasUtils.putActionIntent(anotherPhoneExtras, anotherPhoneIntent);
+
+ Bundle urlExtras = new Bundle();
+ Intent urlIntent = new Intent();
+ urlIntent.setComponent(new ComponentName("url", "intent"));
+ ExtrasUtils.putActionIntent(urlExtras, urlIntent);
+
+ PendingIntent pendingIntent =
+ PendingIntent.getActivity(ApplicationProvider.getApplicationContext(), 0, phoneIntent, 0);
+ Icon icon = Icon.createWithData(new byte[0], 0, 0);
+ ConversationAction action =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "label", "1", pendingIntent))
+ .setExtras(phoneExtras)
+ .build();
+ ConversationAction actionWithSameLabel =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "label", "2", pendingIntent))
+ .setExtras(phoneExtras)
+ .build();
+ ConversationAction actionWithSamePackageButDifferentClass =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "label", "3", pendingIntent))
+ .setExtras(anotherPhoneExtras)
+ .build();
+ ConversationAction actionWithDifferentLabel =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "another_label", "4", pendingIntent))
+ .setExtras(phoneExtras)
+ .build();
+ ConversationAction actionWithDifferentPackage =
+ new ConversationAction.Builder(ConversationAction.TYPE_OPEN_URL)
+ .setAction(new RemoteAction(icon, "label", "5", pendingIntent))
+ .setExtras(urlExtras)
+ .build();
+ ConversationAction actionWithoutRemoteAction =
+ new ConversationAction.Builder(ConversationAction.TYPE_CREATE_REMINDER).build();
+
+ List<ConversationAction> conversationActions =
+ ActionsSuggestionsHelper.removeActionsWithDuplicates(
+ Arrays.asList(
+ action,
+ actionWithSameLabel,
+ actionWithSamePackageButDifferentClass,
+ actionWithDifferentLabel,
+ actionWithDifferentPackage,
+ actionWithoutRemoteAction));
+
+ assertThat(conversationActions).hasSize(3);
+ assertThat(conversationActions.get(0).getAction().getContentDescription().toString())
+ .isEqualTo("4");
+ assertThat(conversationActions.get(1).getAction().getContentDescription().toString())
+ .isEqualTo("5");
+ assertThat(conversationActions.get(2).getAction()).isNull();
+ }
+
+ @Test
+ public void testDeduplicateActions_nullComponent() {
+ Bundle phoneExtras = new Bundle();
+ Intent phoneIntent = new Intent(Intent.ACTION_DIAL);
+ ExtrasUtils.putActionIntent(phoneExtras, phoneIntent);
+ PendingIntent pendingIntent =
+ PendingIntent.getActivity(ApplicationProvider.getApplicationContext(), 0, phoneIntent, 0);
+ Icon icon = Icon.createWithData(new byte[0], 0, 0);
+ ConversationAction action =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "label", "1", pendingIntent))
+ .setExtras(phoneExtras)
+ .build();
+ ConversationAction actionWithSameLabel =
+ new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
+ .setAction(new RemoteAction(icon, "label", "2", pendingIntent))
+ .setExtras(phoneExtras)
+ .build();
+
+ List<ConversationAction> conversationActions =
+ ActionsSuggestionsHelper.removeActionsWithDuplicates(
+ Arrays.asList(action, actionWithSameLabel));
+
+ assertThat(conversationActions).isEmpty();
+ }
+
+ public void createLabeledIntentResult_null() {
+ ActionsSuggestionsModel.ActionSuggestion nativeSuggestion =
+ new ActionsSuggestionsModel.ActionSuggestion(
+ "text", ConversationAction.TYPE_OPEN_URL, 1.0f, null, null, null);
+
+ LabeledIntent.Result labeledIntentResult =
+ ActionsSuggestionsHelper.createLabeledIntentResult(
+ ApplicationProvider.getApplicationContext(),
+ new TemplateIntentFactory(),
+ nativeSuggestion);
+
+ assertThat(labeledIntentResult).isNull();
+ }
+
+ @Test
+ public void createLabeledIntentResult_emptyList() {
+ ActionsSuggestionsModel.ActionSuggestion nativeSuggestion =
+ new ActionsSuggestionsModel.ActionSuggestion(
+ "text",
+ ConversationAction.TYPE_OPEN_URL,
+ 1.0f,
+ null,
+ null,
+ new RemoteActionTemplate[0]);
+
+ LabeledIntent.Result labeledIntentResult =
+ ActionsSuggestionsHelper.createLabeledIntentResult(
+ ApplicationProvider.getApplicationContext(),
+ new TemplateIntentFactory(),
+ nativeSuggestion);
+
+ assertThat(labeledIntentResult).isNull();
+ }
+
+ @Test
+ public void createLabeledIntentResult() {
+ ActionsSuggestionsModel.ActionSuggestion nativeSuggestion =
+ new ActionsSuggestionsModel.ActionSuggestion(
+ "text",
+ ConversationAction.TYPE_OPEN_URL,
+ 1.0f,
+ null,
+ null,
+ new RemoteActionTemplate[] {
+ new RemoteActionTemplate(
+ "title",
+ null,
+ "description",
+ null,
+ Intent.ACTION_VIEW,
+ Uri.parse("http://www.android.com").toString(),
+ null,
+ 0,
+ null,
+ null,
+ null,
+ 0)
+ });
+
+ LabeledIntent.Result labeledIntentResult =
+ ActionsSuggestionsHelper.createLabeledIntentResult(
+ ApplicationProvider.getApplicationContext(),
+ new TemplateIntentFactory(),
+ nativeSuggestion);
+
+ assertThat(labeledIntentResult.remoteAction.getTitle().toString()).isEqualTo("title");
+ assertThat(labeledIntentResult.resolvedIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
+ }
+
+ private static ZonedDateTime createZonedDateTimeFromMsUtc(long msUtc) {
+ return ZonedDateTime.ofInstant(Instant.ofEpochMilli(msUtc), ZoneId.of("UTC"));
+ }
+
+ private static void assertNativeMessage(
+ ActionsSuggestionsModel.ConversationMessage nativeMessage,
+ CharSequence text,
+ int userId,
+ long referenceTimeInMsUtc) {
+ assertThat(nativeMessage.getText()).isEqualTo(text.toString());
+ assertThat(nativeMessage.getUserId()).isEqualTo(userId);
+ assertThat(nativeMessage.getDetectedTextLanguageTags()).isEqualTo(LOCALE_TAG);
+ assertThat(nativeMessage.getReferenceTimeMsUtc()).isEqualTo(referenceTimeInMsUtc);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/FakeContextBuilder.java b/java/tests/instrumentation/src/com/android/textclassifier/FakeContextBuilder.java
new file mode 100644
index 0000000..17b6e0a
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/FakeContextBuilder.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.ContextWrapper;
+import android.content.Intent;
+import android.content.pm.ActivityInfo;
+import android.content.pm.ApplicationInfo;
+import android.content.pm.PackageManager;
+import android.content.pm.ResolveInfo;
+import androidx.test.core.app.ApplicationProvider;
+import com.google.common.base.Preconditions;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.UUID;
+import javax.annotation.Nullable;
+import org.mockito.stubbing.Answer;
+
+/** A builder used to build a fake context for testing. */
+public final class FakeContextBuilder {
+
+ /** A component name that can be used for tests. */
+ public static final ComponentName DEFAULT_COMPONENT = new ComponentName("pkg", "cls");
+
+ private final PackageManager packageManager;
+ private final ContextWrapper context;
+ private final Map<String, ComponentName> components = new HashMap<>();
+ private final Map<String, CharSequence> appLabels = new HashMap<>();
+ @Nullable private ComponentName allIntentComponent;
+
+ public FakeContextBuilder() {
+ packageManager = mock(PackageManager.class);
+ when(packageManager.resolveActivity(any(Intent.class), anyInt())).thenReturn(null);
+ context =
+ new ContextWrapper(ApplicationProvider.getApplicationContext()) {
+ @Override
+ public PackageManager getPackageManager() {
+ return packageManager;
+ }
+ };
+ }
+
+ /**
+ * Sets the component name of an activity to handle the specified intent action.
+ *
+ * <p><strong>NOTE: </strong>By default, no component is set to handle any intent.
+ */
+ public FakeContextBuilder setIntentComponent(
+ String intentAction, @Nullable ComponentName component) {
+ Preconditions.checkNotNull(intentAction);
+ components.put(intentAction, component);
+ return this;
+ }
+
+ /** Sets the app label res for a specified package. */
+ public FakeContextBuilder setAppLabel(String packageName, @Nullable CharSequence appLabel) {
+ Preconditions.checkNotNull(packageName);
+ appLabels.put(packageName, appLabel);
+ return this;
+ }
+
+ /**
+ * Sets the component name of an activity to handle all intents.
+ *
+ * <p><strong>NOTE: </strong>By default, no component is set to handle any intent.
+ */
+ public FakeContextBuilder setAllIntentComponent(@Nullable ComponentName component) {
+ allIntentComponent = component;
+ return this;
+ }
+
+ /** Builds and returns a fake context. */
+ public Context build() {
+ when(packageManager.resolveActivity(any(Intent.class), anyInt()))
+ .thenAnswer(
+ (Answer<ResolveInfo>)
+ invocation -> {
+ final String action = ((Intent) invocation.getArgument(0)).getAction();
+ final ComponentName component =
+ components.containsKey(action) ? components.get(action) : allIntentComponent;
+ return getResolveInfo(component);
+ });
+ when(packageManager.getApplicationLabel(any(ApplicationInfo.class)))
+ .thenAnswer(
+ (Answer<CharSequence>)
+ invocation -> {
+ ApplicationInfo applicationInfo = invocation.getArgument(0);
+ return appLabels.get(applicationInfo.packageName);
+ });
+ return context;
+ }
+
+ /** Returns a component name with random package and class names. */
+ public static ComponentName newComponent() {
+ return new ComponentName(UUID.randomUUID().toString(), UUID.randomUUID().toString());
+ }
+
+ private static ResolveInfo getResolveInfo(ComponentName component) {
+ final ResolveInfo info;
+ if (component == null) {
+ info = null;
+ } else {
+ // NOTE: If something breaks in TextClassifier because we expect more fields to be set
+ // in here, just add them.
+ info = new ResolveInfo();
+ info.activityInfo = new ActivityInfo();
+ info.activityInfo.packageName = component.getPackageName();
+ info.activityInfo.name = component.getClassName();
+ info.activityInfo.exported = true;
+ info.activityInfo.applicationInfo = new ApplicationInfo();
+ info.activityInfo.applicationInfo.packageName = component.getPackageName();
+ info.activityInfo.applicationInfo.icon = 0;
+ }
+ return info;
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
new file mode 100644
index 0000000..ab4fde4
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
@@ -0,0 +1,353 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.when;
+
+import android.os.LocaleList;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.google.common.collect.ImmutableList;
+import java.io.File;
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.Locale;
+import java.util.function.Supplier;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class ModelFileManagerTest {
+ private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
+ @Mock private Supplier<List<ModelFileManager.ModelFile>> modelFileSupplier;
+ private ModelFileManager.ModelFileSupplierImpl modelFileSupplierImpl;
+ private ModelFileManager modelFileManager;
+ private File rootTestDir;
+ private File factoryModelDir;
+ private File updatedModelFile;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ modelFileManager = new ModelFileManager(modelFileSupplier);
+ rootTestDir = ApplicationProvider.getApplicationContext().getCacheDir();
+ factoryModelDir = new File(rootTestDir, "factory");
+ updatedModelFile = new File(rootTestDir, "updated.model");
+
+ modelFileSupplierImpl =
+ new ModelFileManager.ModelFileSupplierImpl(
+ factoryModelDir,
+ "test\\d.model",
+ updatedModelFile,
+ fd -> 1,
+ fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT);
+
+ rootTestDir.mkdirs();
+ factoryModelDir.mkdirs();
+
+ Locale.setDefault(DEFAULT_LOCALE);
+ }
+
+ @After
+ public void removeTestDir() {
+ recursiveDelete(rootTestDir);
+ }
+
+ @Test
+ public void get() {
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+ when(modelFileSupplier.get()).thenReturn(ImmutableList.of(modelFile));
+
+ List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles();
+
+ assertThat(modelFiles).hasSize(1);
+ assertThat(modelFiles.get(0)).isEqualTo(modelFile);
+ }
+
+ @Test
+ public void findBestModel_versionCode() {
+ ModelFileManager.ModelFile olderModelFile =
+ new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+
+ ModelFileManager.ModelFile newerModelFile =
+ new ModelFileManager.ModelFile(new File("/path/b"), 2, ImmutableList.of(), "", true);
+ when(modelFileSupplier.get()).thenReturn(Arrays.asList(olderModelFile, newerModelFile));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList());
+
+ assertThat(bestModelFile).isEqualTo(newerModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageDependentModelIsPreferred() {
+ Locale locale = Locale.forLanguageTag("ja");
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+
+ ModelFileManager.ModelFile languageDependentModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 1,
+ Collections.singletonList(locale),
+ locale.toLanguageTag(),
+ false);
+ when(modelFileSupplier.get())
+ .thenReturn(Arrays.asList(languageIndependentModelFile, languageDependentModelFile));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.forLanguageTags(locale.toLanguageTag()));
+ assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_noMatchedLanguageModel() {
+ Locale locale = Locale.forLanguageTag("ja");
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ new ModelFileManager.ModelFile(new File("/path/a"), 1, Collections.emptyList(), "", true);
+
+ ModelFileManager.ModelFile languageDependentModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 1,
+ Collections.singletonList(locale),
+ locale.toLanguageTag(),
+ false);
+
+ when(modelFileSupplier.get())
+ .thenReturn(Arrays.asList(languageIndependentModelFile, languageDependentModelFile));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.forLanguageTags("zh-hk"));
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() {
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ new ModelFileManager.ModelFile(new File("/path/a"), 1, ImmutableList.of(), "", true);
+
+ ModelFileManager.ModelFile languageDependentModelFile =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 1,
+ Collections.singletonList(DEFAULT_LOCALE),
+ DEFAULT_LOCALE.toLanguageTag(),
+ false);
+
+ when(modelFileSupplier.get())
+ .thenReturn(Arrays.asList(languageIndependentModelFile, languageDependentModelFile));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.forLanguageTags("zh-hk"));
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageIsMoreImportantThanVersion() {
+ ModelFileManager.ModelFile matchButOlderModel =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("fr")),
+ "fr",
+ false);
+
+ ModelFileManager.ModelFile mismatchButNewerModel =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 2,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ when(modelFileSupplier.get())
+ .thenReturn(Arrays.asList(matchButOlderModel, mismatchButNewerModel));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.forLanguageTags("fr"));
+ assertThat(bestModelFile).isEqualTo(matchButOlderModel);
+ }
+
+ @Test
+ public void findBestModel_languageIsMoreImportantThanVersion_bestModelComesFirst() {
+ ModelFileManager.ModelFile matchLocaleModel =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ ModelFileManager.ModelFile languageIndependentModel =
+ new ModelFileManager.ModelFile(new File("/path/a"), 2, ImmutableList.of(), "", true);
+ when(modelFileSupplier.get())
+ .thenReturn(Arrays.asList(matchLocaleModel, languageIndependentModel));
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(LocaleList.forLanguageTags("ja"));
+
+ assertThat(bestModelFile).isEqualTo(matchLocaleModel);
+ }
+
+ @Test
+ public void modelFileEquals() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ assertThat(modelA).isEqualTo(modelB);
+ }
+
+ @Test
+ public void modelFile_different() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ new File("/path/b"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ assertThat(modelA).isNotEqualTo(modelB);
+ }
+
+ @Test
+ public void modelFile_getPath() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ assertThat(modelA.getPath()).isEqualTo("/path/a");
+ }
+
+ @Test
+ public void modelFile_getName() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ assertThat(modelA.getName()).isEqualTo("a");
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_languageDependentIsBetter() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 1,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(new File("/path/b"), 2, ImmutableList.of(), "", true);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_version() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ new File("/path/a"),
+ 2,
+ Collections.singletonList(Locale.forLanguageTag("ja")),
+ "ja",
+ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(new File("/path/b"), 1, Collections.emptyList(), "", false);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void testFileSupplierImpl_updatedFileOnly() throws IOException {
+ updatedModelFile.createNewFile();
+ File model1 = new File(factoryModelDir, "test1.model");
+ model1.createNewFile();
+ File model2 = new File(factoryModelDir, "test2.model");
+ model2.createNewFile();
+ new File(factoryModelDir, "not_match_regex.model").createNewFile();
+
+ List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
+ List<String> modelFilePaths =
+ modelFiles.stream().map(modelFile -> modelFile.getPath()).collect(Collectors.toList());
+
+ assertThat(modelFiles).hasSize(3);
+ assertThat(modelFilePaths)
+ .containsExactly(
+ updatedModelFile.getAbsolutePath(), model1.getAbsolutePath(), model2.getAbsolutePath());
+ }
+
+ @Test
+ public void testFileSupplierImpl_empty() {
+ factoryModelDir.delete();
+ List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
+
+ assertThat(modelFiles).hasSize(0);
+ }
+
+ private static void recursiveDelete(File f) {
+ if (f.isDirectory()) {
+ for (File innerFile : f.listFiles()) {
+ recursiveDelete(innerFile);
+ }
+ }
+ f.delete();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/StringUtilsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/StringUtilsTest.java
new file mode 100644
index 0000000..7511e45
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/StringUtilsTest.java
@@ -0,0 +1,90 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.testng.Assert.assertThrows;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class StringUtilsTest {
+
+ @Test
+ public void testGetSubString() {
+ final String text = "Yakuza call themselves 任侠団体";
+ int start;
+ int end;
+ int minimumLength;
+
+ // End index at end of text.
+ start = text.indexOf("任侠団体");
+ end = text.length();
+ minimumLength = 20;
+ assertThat(StringUtils.getSubString(text, start, end, minimumLength))
+ .isEqualTo("call themselves 任侠団体");
+
+ // Start index at beginning of text.
+ start = 0;
+ end = "Yakuza".length();
+ minimumLength = 15;
+ assertThat(StringUtils.getSubString(text, start, end, minimumLength))
+ .isEqualTo("Yakuza call themselves");
+
+ // Text in the middle
+ start = text.indexOf("all");
+ end = start + 1;
+ minimumLength = 10;
+ assertThat(StringUtils.getSubString(text, start, end, minimumLength))
+ .isEqualTo("Yakuza call themselves");
+
+ // Selection >= minimumLength.
+ start = text.indexOf("themselves");
+ end = start + "themselves".length();
+ minimumLength = end - start;
+ assertThat(StringUtils.getSubString(text, start, end, minimumLength)).isEqualTo("themselves");
+
+ // text.length < minimumLength.
+ minimumLength = text.length() + 1;
+ assertThat(StringUtils.getSubString(text, start, end, minimumLength)).isEqualTo(text);
+ }
+
+ @Test
+ public void testGetSubString_invalidParams() {
+ final String text = "The Yoruba regard Olodumare as the principal agent of creation";
+ final int length = text.length();
+ final int minimumLength = 10;
+
+ // Null text
+ assertThrows(
+ NullPointerException.class, () -> StringUtils.getSubString(null, 0, 1, minimumLength));
+ // start > end
+ assertThrows(
+ IllegalArgumentException.class, () -> StringUtils.getSubString(text, 6, 5, minimumLength));
+ // start < 0
+ assertThrows(
+ IllegalArgumentException.class, () -> StringUtils.getSubString(text, -1, 5, minimumLength));
+ // end > text.length
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> StringUtils.getSubString(text, 6, length + 1, minimumLength));
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassificationConstantsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassificationConstantsTest.java
new file mode 100644
index 0000000..6ec2d2c
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassificationConstantsTest.java
@@ -0,0 +1,104 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertWithMessage;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassificationConstantsTest {
+
+ private static final float EPSILON = 0.0001f;
+
+ @Test
+ public void testLoadFromString_defaultValues() {
+ final TextClassificationConstants constants = new TextClassificationConstants();
+
+ assertWithMessage("suggest_selection_max_range_length")
+ .that(constants.getSuggestSelectionMaxRangeLength())
+ .isEqualTo(10 * 1000);
+ assertWithMessage("classify_text_max_range_length")
+ .that(constants.getClassifyTextMaxRangeLength())
+ .isEqualTo(10 * 1000);
+ assertWithMessage("generate_links_max_text_length")
+ .that(constants.getGenerateLinksMaxTextLength())
+ .isEqualTo(100 * 1000);
+ // assertWithMessage("generate_links_log_sample_rate")
+ // .that(constants.getGenerateLinksLogSampleRate()).isEqualTo(100);
+ assertWithMessage("entity_list_default")
+ .that(constants.getEntityListDefault())
+ .containsExactly("address", "email", "url", "phone", "date", "datetime", "flight");
+ assertWithMessage("entity_list_not_editable")
+ .that(constants.getEntityListNotEditable())
+ .containsExactly("address", "email", "url", "phone", "date", "datetime", "flight");
+ assertWithMessage("entity_list_editable")
+ .that(constants.getEntityListEditable())
+ .containsExactly("address", "email", "url", "phone", "date", "datetime", "flight");
+ assertWithMessage("in_app_conversation_action_types_default")
+ .that(constants.getInAppConversationActionTypes())
+ .containsExactly(
+ "text_reply",
+ "create_reminder",
+ "call_phone",
+ "open_url",
+ "send_email",
+ "send_sms",
+ "track_flight",
+ "view_calendar",
+ "view_map",
+ "add_contact",
+ "copy");
+ assertWithMessage("notification_conversation_action_types_default")
+ .that(constants.getNotificationConversationActionTypes())
+ .containsExactly(
+ "text_reply",
+ "create_reminder",
+ "call_phone",
+ "open_url",
+ "send_email",
+ "send_sms",
+ "track_flight",
+ "view_calendar",
+ "view_map",
+ "add_contact",
+ "copy");
+ assertWithMessage("lang_id_threshold_override")
+ .that(constants.getLangIdThresholdOverride())
+ .isWithin(EPSILON)
+ .of(-1f);
+ Assert.assertArrayEquals(
+ "lang_id_context_settings",
+ constants.getLangIdContextSettings(),
+ new float[] {20, 1, 0.4f},
+ EPSILON);
+ assertWithMessage("detect_language_from_text_enabled")
+ .that(constants.isDetectLanguagesFromTextEnabled())
+ .isTrue();
+ assertWithMessage("template_intent_factory_enabled")
+ .that(constants.isTemplateIntentFactoryEnabled())
+ .isTrue();
+ assertWithMessage("translate_in_classification_enabled")
+ .that(constants.isTranslateInClassificationEnabled())
+ .isTrue();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
new file mode 100644
index 0000000..2ebfa6b
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -0,0 +1,653 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static org.hamcrest.CoreMatchers.not;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.testng.Assert.assertThrows;
+
+import android.app.RemoteAction;
+import android.content.Context;
+import android.content.Intent;
+import android.net.Uri;
+import android.os.Bundle;
+import android.os.LocaleList;
+import android.text.Spannable;
+import android.text.SpannableString;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextLanguage;
+import android.view.textclassifier.TextLinks;
+import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.google.common.collect.ImmutableList;
+import com.google.common.truth.Truth;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.hamcrest.BaseMatcher;
+import org.hamcrest.Description;
+import org.hamcrest.Matcher;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/**
+ * Testing {@link TextClassifierImplTest} APIs on local and system textclassifier.
+ *
+ * <p>Tests are skipped if such a textclassifier does not exist.
+ */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassifierImplTest {
+
+ // TODO: Implement TextClassifierService testing.
+ private static final String TYPE_COPY = "copy";
+ private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
+ private static final String NO_TYPE = null;
+
+ private TextClassifierImpl classifier;
+
+ @Before
+ public void setup() {
+ Context context = ApplicationProvider.getApplicationContext();
+ classifier = new TextClassifierImpl(context, new TextClassificationConstants());
+ }
+
+ @Test
+ public void testSuggestSelection() {
+ String text = "Contact me at droid@android.com";
+ String selected = "droid";
+ String suggested = "droid@android.com";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ int smartStartIndex = text.indexOf(suggested);
+ int smartEndIndex = smartStartIndex + suggested.length();
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextSelection selection = classifier.suggestSelection(request);
+ assertThat(
+ selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL));
+ }
+
+ @Test
+ public void testSuggestSelection_url() {
+ String text = "Visit http://www.android.com for more information";
+ String selected = "http";
+ String suggested = "http://www.android.com";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ int smartStartIndex = text.indexOf(suggested);
+ int smartEndIndex = smartStartIndex + suggested.length();
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextSelection selection = classifier.suggestSelection(request);
+ assertThat(selection, isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
+ }
+
+ @Test
+ public void testSmartSelection_withEmoji() {
+ String text = "\uD83D\uDE02 Hello.";
+ String selected = "Hello";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextSelection selection = classifier.suggestSelection(request);
+ assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
+ }
+
+ @Test
+ public void testClassifyText() {
+ String text = "Contact me at droid@android.com";
+ String classifiedText = "droid@android.com";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL));
+ }
+
+ @Test
+ public void testClassifyText_url() {
+ String text = "Visit www.android.com for more information";
+ String classifiedText = "www.android.com";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
+ assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
+ }
+
+ @Test
+ public void testClassifyText_address() {
+ String text = "Brandschenkestrasse 110, Zürich, Switzerland";
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, 0, text.length())
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
+ }
+
+ @Test
+ public void testClassifyText_url_inCaps() {
+ String text = "Visit HTTP://ANDROID.COM for more information";
+ String classifiedText = "HTTP://ANDROID.COM";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
+ assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
+ }
+
+ @Test
+ public void testClassifyText_date() {
+ String text = "Let's meet on January 9, 2018.";
+ String classifiedText = "January 9, 2018";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
+ Bundle extras = classification.getExtras();
+ List<Bundle> entities = ExtrasUtils.getEntities(extras);
+ Truth.assertThat(entities).hasSize(1);
+ Bundle entity = entities.get(0);
+ Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_DATE);
+ }
+
+ @Test
+ public void testClassifyText_datetime() {
+ String text = "Let's meet 2018/01/01 10:30:20.";
+ String classifiedText = "2018/01/01 10:30:20";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex)
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = classifier.classifyText(request);
+ assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
+ }
+
+ @Test
+ public void testClassifyText_foreignText() {
+ LocaleList originalLocales = LocaleList.getDefault();
+ LocaleList.setDefault(LocaleList.forLanguageTags("en"));
+ String japaneseText = "これは日本語のテキストです";
+
+ Context context =
+ new FakeContextBuilder()
+ .setIntentComponent(Intent.ACTION_TRANSLATE, FakeContextBuilder.DEFAULT_COMPONENT)
+ .build();
+ TextClassifierImpl textClassifier =
+ new TextClassifierImpl(context, new TextClassificationConstants());
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length())
+ .setDefaultLocales(LOCALES)
+ .build();
+
+ TextClassification classification = textClassifier.classifyText(request);
+ RemoteAction translateAction = classification.getActions().get(0);
+ assertEquals(1, classification.getActions().size());
+ assertEquals(
+ context.getString(com.android.textclassifier.R.string.translate),
+ translateAction.getTitle().toString());
+
+ assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
+ Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
+ assertEquals(Intent.ACTION_TRANSLATE, intent.getAction());
+ Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification);
+ assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo));
+ assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0);
+ assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1);
+ assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER));
+ assertEquals("ja", ExtrasUtils.getTopLanguage(intent).getLanguage());
+
+ LocaleList.setDefault(originalLocales);
+ }
+
+ @Test
+ public void testGenerateLinks_phone() {
+ String text = "The number is +12122537077. See you tonight!";
+ TextLinks.Request request = new TextLinks.Request.Builder(text).build();
+ assertThat(
+ classifier.generateLinks(request),
+ isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
+ }
+
+ @Test
+ public void testGenerateLinks_exclude() {
+ String text = "You want apple@banana.com. See you tonight!";
+ List<String> hints = ImmutableList.of();
+ List<String> included = ImmutableList.of();
+ List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
+ TextLinks.Request request =
+ new TextLinks.Request.Builder(text)
+ .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
+ .setDefaultLocales(LOCALES)
+ .build();
+ assertThat(
+ classifier.generateLinks(request),
+ not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
+ }
+
+ @Test
+ public void testGenerateLinks_explicit_address() {
+ String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
+ List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
+ TextLinks.Request request =
+ new TextLinks.Request.Builder(text)
+ .setEntityConfig(TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
+ .setDefaultLocales(LOCALES)
+ .build();
+ assertThat(
+ classifier.generateLinks(request),
+ isTextLinksContaining(
+ text, "1600 Amphitheater Parkway, Mountain View, CA", TextClassifier.TYPE_ADDRESS));
+ }
+
+ @Test
+ public void testGenerateLinks_exclude_override() {
+ String text = "You want apple@banana.com. See you tonight!";
+ List<String> hints = ImmutableList.of();
+ List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
+ List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
+ TextLinks.Request request =
+ new TextLinks.Request.Builder(text)
+ .setEntityConfig(TextClassifier.EntityConfig.create(hints, included, excluded))
+ .setDefaultLocales(LOCALES)
+ .build();
+ assertThat(
+ classifier.generateLinks(request),
+ not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
+ }
+
+ @Test
+ public void testGenerateLinks_maxLength() {
+ char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()];
+ Arrays.fill(manySpaces, ' ');
+ TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
+ TextLinks links = classifier.generateLinks(request);
+ assertTrue(links.getLinks().isEmpty());
+ }
+
+ @Test
+ public void testApplyLinks_unsupportedCharacter() {
+ Spannable url = new SpannableString("\u202Emoc.diordna.com");
+ TextLinks.Request request = new TextLinks.Request.Builder(url).build();
+ assertEquals(
+ TextLinks.STATUS_UNSUPPORTED_CHARACTER,
+ classifier.generateLinks(request).apply(url, 0, null));
+ }
+
+ @Test
+ public void testGenerateLinks_tooLong() {
+ char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1];
+ Arrays.fill(manySpaces, ' ');
+ TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
+ assertThrows(IllegalArgumentException.class, () -> classifier.generateLinks(request));
+ }
+
+ @Test
+ public void testGenerateLinks_entityData() {
+ String text = "The number is +12122537077.";
+ Bundle extras = new Bundle();
+ ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
+ TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();
+
+ TextLinks textLinks = classifier.generateLinks(request);
+
+ Truth.assertThat(textLinks.getLinks()).hasSize(1);
+ TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
+ List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
+ Truth.assertThat(entities).hasSize(1);
+ Bundle entity = entities.get(0);
+ Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE);
+ }
+
+ @Test
+ public void testGenerateLinks_entityData_disabled() {
+ String text = "The number is +12122537077.";
+ TextLinks.Request request = new TextLinks.Request.Builder(text).build();
+
+ TextLinks textLinks = classifier.generateLinks(request);
+
+ Truth.assertThat(textLinks.getLinks()).hasSize(1);
+ TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
+ List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
+ Truth.assertThat(entities).isNull();
+ }
+
+ @Test
+ public void testDetectLanguage() {
+ String text = "This is English text";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+ TextLanguage textLanguage = classifier.detectLanguage(request);
+ assertThat(textLanguage, isTextLanguage("en"));
+ }
+
+ @Test
+ public void testDetectLanguage_japanese() {
+ String text = "これは日本語のテキストです";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+ TextLanguage textLanguage = classifier.detectLanguage(request);
+ assertThat(textLanguage, isTextLanguage("ja"));
+ }
+
+ @Ignore // Doesn't work without a language-based model.
+ @Test
+ public void testSuggestConversationActions_textReplyOnly_maxOne() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Where are you?")
+ .build();
+ TextClassifier.EntityConfig typeConfig =
+ new TextClassifier.EntityConfig.Builder()
+ .includeTypesFromTextClassifier(false)
+ .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(Collections.singletonList(message))
+ .setMaxSuggestions(1)
+ .setTypeConfig(typeConfig)
+ .build();
+
+ ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ Truth.assertThat(conversationActions.getConversationActions()).hasSize(1);
+ ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
+ Truth.assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
+ Truth.assertThat(conversationAction.getTextReply()).isNotNull();
+ }
+
+ @Ignore // Doesn't work without a language-based model.
+ @Test
+ public void testSuggestConversationActions_textReplyOnly_noMax() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Where are you?")
+ .build();
+ TextClassifier.EntityConfig typeConfig =
+ new TextClassifier.EntityConfig.Builder()
+ .includeTypesFromTextClassifier(false)
+ .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(Collections.singletonList(message))
+ .setTypeConfig(typeConfig)
+ .build();
+
+ ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ assertTrue(conversationActions.getConversationActions().size() > 1);
+ for (ConversationAction conversationAction : conversationActions.getConversationActions()) {
+ assertThat(conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY));
+ }
+ }
+
+ @Test
+ public void testSuggestConversationActions_openUrl() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Check this out: https://www.android.com")
+ .build();
+ TextClassifier.EntityConfig typeConfig =
+ new TextClassifier.EntityConfig.Builder()
+ .includeTypesFromTextClassifier(false)
+ .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(Collections.singletonList(message))
+ .setMaxSuggestions(1)
+ .setTypeConfig(typeConfig)
+ .build();
+
+ ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ Truth.assertThat(conversationActions.getConversationActions()).hasSize(1);
+ ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
+ Truth.assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
+ Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
+ Truth.assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
+ Truth.assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com"));
+ }
+
+ @Ignore // Doesn't work without a language-based model.
+ @Test
+ public void testSuggestConversationActions_copy() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Authentication code: 12345")
+ .build();
+ TextClassifier.EntityConfig typeConfig =
+ new TextClassifier.EntityConfig.Builder()
+ .includeTypesFromTextClassifier(false)
+ .setIncludedTypes(Collections.singletonList(TYPE_COPY))
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(Collections.singletonList(message))
+ .setMaxSuggestions(1)
+ .setTypeConfig(typeConfig)
+ .build();
+
+ ConversationActions conversationActions = classifier.suggestConversationActions(request);
+ Truth.assertThat(conversationActions.getConversationActions()).hasSize(1);
+ ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
+ Truth.assertThat(conversationAction.getType()).isEqualTo(TYPE_COPY);
+ Truth.assertThat(conversationAction.getTextReply()).isAnyOf(null, "");
+ Truth.assertThat(conversationAction.getAction()).isNull();
+ String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
+ Truth.assertThat(code).isEqualTo("12345");
+ Truth.assertThat(ExtrasUtils.getSerializedEntityData(conversationAction.getExtras()))
+ .isNotEmpty();
+ }
+
+ @Test
+ public void testSuggetsConversationActions_deduplicate() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("a@android.com b@android.com")
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(Collections.singletonList(message))
+ .setMaxSuggestions(3)
+ .build();
+
+ ConversationActions conversationActions = classifier.suggestConversationActions(request);
+
+ Truth.assertThat(conversationActions.getConversationActions()).isEmpty();
+ }
+
+ private static Matcher<TextSelection> isTextSelection(
+ final int startIndex, final int endIndex, final String type) {
+ return new BaseMatcher<TextSelection>() {
+ @Override
+ public boolean matches(Object o) {
+ if (o instanceof TextSelection) {
+ TextSelection selection = (TextSelection) o;
+ return startIndex == selection.getSelectionStartIndex()
+ && endIndex == selection.getSelectionEndIndex()
+ && typeMatches(selection, type);
+ }
+ return false;
+ }
+
+ private boolean typeMatches(TextSelection selection, String type) {
+ return type == null
+ || (selection.getEntityCount() > 0
+ && type.trim().equalsIgnoreCase(selection.getEntity(0)));
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ description.appendValue(String.format("%d, %d, %s", startIndex, endIndex, type));
+ }
+ };
+ }
+
+ private static Matcher<TextLinks> isTextLinksContaining(
+ final String text, final String substring, final String type) {
+ return new BaseMatcher<TextLinks>() {
+
+ @Override
+ public void describeTo(Description description) {
+ description
+ .appendText("text=")
+ .appendValue(text)
+ .appendText(", substring=")
+ .appendValue(substring)
+ .appendText(", type=")
+ .appendValue(type);
+ }
+
+ @Override
+ public boolean matches(Object o) {
+ if (o instanceof TextLinks) {
+ for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) {
+ if (text.subSequence(link.getStart(), link.getEnd()).toString().equals(substring)) {
+ return type.equals(link.getEntity(0));
+ }
+ }
+ }
+ return false;
+ }
+ };
+ }
+
+ private static Matcher<TextClassification> isTextClassification(
+ final String text, final String type) {
+ return new BaseMatcher<TextClassification>() {
+ @Override
+ public boolean matches(Object o) {
+ if (o instanceof TextClassification) {
+ TextClassification result = (TextClassification) o;
+ return text.equals(result.getText())
+ && result.getEntityCount() > 0
+ && type.equals(result.getEntity(0));
+ }
+ return false;
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ description.appendText("text=").appendValue(text).appendText(", type=").appendValue(type);
+ }
+ };
+ }
+
+ private static Matcher<TextClassification> containsIntentWithAction(final String action) {
+ return new BaseMatcher<TextClassification>() {
+ @Override
+ public boolean matches(Object o) {
+ if (o instanceof TextClassification) {
+ TextClassification result = (TextClassification) o;
+ return ExtrasUtils.findAction(result, action) != null;
+ }
+ return false;
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ description.appendText("intent action=").appendValue(action);
+ }
+ };
+ }
+
+ private static Matcher<TextLanguage> isTextLanguage(final String languageTag) {
+ return new BaseMatcher<TextLanguage>() {
+ @Override
+ public boolean matches(Object o) {
+ if (o instanceof TextLanguage) {
+ TextLanguage result = (TextLanguage) o;
+ return result.getLocaleHypothesisCount() > 0
+ && languageTag.equals(result.getLocale(0).toLanguageTag());
+ }
+ return false;
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ description.appendText("locale=").appendValue(languageTag);
+ }
+ };
+ }
+
+ private static Matcher<ConversationAction> isConversationAction(String actionType) {
+ return new BaseMatcher<ConversationAction>() {
+ @Override
+ public boolean matches(Object o) {
+ if (!(o instanceof ConversationAction)) {
+ return false;
+ }
+ ConversationAction conversationAction = (ConversationAction) o;
+ if (!actionType.equals(conversationAction.getType())) {
+ return false;
+ }
+ if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)) {
+ if (conversationAction.getTextReply() == null) {
+ return false;
+ }
+ }
+ if (conversationAction.getConfidenceScore() < 0
+ || conversationAction.getConfidenceScore() > 1) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public void describeTo(Description description) {
+ description.appendText("actionType=").appendValue(actionType);
+ }
+ };
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/intent/LabeledIntentTest.java b/java/tests/instrumentation/src/com/android/textclassifier/intent/LabeledIntentTest.java
new file mode 100644
index 0000000..3840823
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/intent/LabeledIntentTest.java
@@ -0,0 +1,166 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.intent;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.testng.Assert.assertThrows;
+
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.Intent;
+import android.net.Uri;
+import android.os.Bundle;
+import android.view.textclassifier.TextClassifier;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.FakeContextBuilder;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public final class LabeledIntentTest {
+ private static final String TITLE_WITHOUT_ENTITY = "Map";
+ private static final String TITLE_WITH_ENTITY = "Map NW14D1";
+ private static final String DESCRIPTION = "Check the map";
+ private static final String DESCRIPTION_WITH_APP_NAME = "Use %1$s to open map";
+ private static final Intent INTENT =
+ new Intent(Intent.ACTION_VIEW).setDataAndNormalize(Uri.parse("http://www.android.com"));
+ private static final int REQUEST_CODE = 42;
+ private static final Bundle TEXT_LANGUAGES_BUNDLE = Bundle.EMPTY;
+ private static final String APP_LABEL = "fake";
+
+ private Context context;
+
+ @Before
+ public void setup() {
+ final ComponentName component = FakeContextBuilder.DEFAULT_COMPONENT;
+ context =
+ new FakeContextBuilder()
+ .setIntentComponent(Intent.ACTION_VIEW, component)
+ .setAppLabel(component.getPackageName(), APP_LABEL)
+ .build();
+ }
+
+ @Test
+ public void resolve_preferTitleWithEntity() {
+ LabeledIntent labeledIntent =
+ new LabeledIntent(
+ TITLE_WITHOUT_ENTITY, TITLE_WITH_ENTITY, DESCRIPTION, null, INTENT, REQUEST_CODE);
+
+ LabeledIntent.Result result =
+ labeledIntent.resolve(context, /*titleChooser*/ null, TEXT_LANGUAGES_BUNDLE);
+
+ assertThat(result).isNotNull();
+ assertThat(result.remoteAction.getTitle().toString()).isEqualTo(TITLE_WITH_ENTITY);
+ assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+ Intent intent = result.resolvedIntent;
+ assertThat(intent.getAction()).isEqualTo(intent.getAction());
+ assertThat(intent.getComponent()).isNotNull();
+ assertThat(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER)).isTrue();
+ }
+
+ @Test
+ public void resolve_useAvailableTitle() {
+ LabeledIntent labeledIntent =
+ new LabeledIntent(TITLE_WITHOUT_ENTITY, null, DESCRIPTION, null, INTENT, REQUEST_CODE);
+
+ LabeledIntent.Result result =
+ labeledIntent.resolve(context, /*titleChooser*/ null, TEXT_LANGUAGES_BUNDLE);
+
+ assertThat(result).isNotNull();
+ assertThat(result.remoteAction.getTitle().toString()).isEqualTo(TITLE_WITHOUT_ENTITY);
+ assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+ Intent intent = result.resolvedIntent;
+ assertThat(intent.getAction()).isEqualTo(intent.getAction());
+ assertThat(intent.getComponent()).isNotNull();
+ }
+
+ @Test
+ public void resolve_titleChooser() {
+ LabeledIntent labeledIntent =
+ new LabeledIntent(TITLE_WITHOUT_ENTITY, null, DESCRIPTION, null, INTENT, REQUEST_CODE);
+
+ LabeledIntent.Result result =
+ labeledIntent.resolve(
+ context, (labeledIntent1, resolveInfo) -> "chooser", TEXT_LANGUAGES_BUNDLE);
+
+ assertThat(result).isNotNull();
+ assertThat(result.remoteAction.getTitle().toString()).isEqualTo("chooser");
+ assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+ Intent intent = result.resolvedIntent;
+ assertThat(intent.getAction()).isEqualTo(intent.getAction());
+ assertThat(intent.getComponent()).isNotNull();
+ }
+
+ @Test
+ public void resolve_titleChooserReturnsNull() {
+ LabeledIntent labeledIntent =
+ new LabeledIntent(TITLE_WITHOUT_ENTITY, null, DESCRIPTION, null, INTENT, REQUEST_CODE);
+
+ LabeledIntent.Result result =
+ labeledIntent.resolve(
+ context, (labeledIntent1, resolveInfo) -> null, TEXT_LANGUAGES_BUNDLE);
+
+ assertThat(result).isNotNull();
+ assertThat(result.remoteAction.getTitle().toString()).isEqualTo(TITLE_WITHOUT_ENTITY);
+ assertThat(result.remoteAction.getContentDescription().toString()).isEqualTo(DESCRIPTION);
+ Intent intent = result.resolvedIntent;
+ assertThat(intent.getAction()).isEqualTo(intent.getAction());
+ assertThat(intent.getComponent()).isNotNull();
+ }
+
+ @Test
+ public void resolve_missingTitle() {
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> new LabeledIntent(null, null, DESCRIPTION, null, INTENT, REQUEST_CODE));
+ }
+
+ @Test
+ public void resolve_noIntentHandler() {
+ // See setup(). context can only resolve Intent.ACTION_VIEW.
+ Intent unresolvableIntent = new Intent(Intent.ACTION_TRANSLATE);
+ LabeledIntent labeledIntent =
+ new LabeledIntent(
+ TITLE_WITHOUT_ENTITY, null, DESCRIPTION, null, unresolvableIntent, REQUEST_CODE);
+
+ LabeledIntent.Result result = labeledIntent.resolve(context, null, null);
+
+ assertThat(result).isNull();
+ }
+
+ @Test
+ public void resolve_descriptionWithAppName() {
+ LabeledIntent labeledIntent =
+ new LabeledIntent(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ INTENT,
+ REQUEST_CODE);
+
+ LabeledIntent.Result result =
+ labeledIntent.resolve(context, /*titleChooser*/ null, TEXT_LANGUAGES_BUNDLE);
+
+ assertThat(result).isNotNull();
+ assertThat(result.remoteAction.getContentDescription().toString())
+ .isEqualTo("Use fake to open map");
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java
new file mode 100644
index 0000000..389c98e
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java
@@ -0,0 +1,120 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.intent;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Intent;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.google.android.textclassifier.AnnotatorModel;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class LegacyIntentClassificationFactoryTest {
+
+ private static final String TEXT = "text";
+ private static final String TYPE_DICTIONARY = "dictionary";
+
+ private LegacyClassificationIntentFactory legacyIntentClassificationFactory;
+
+ @Before
+ public void setup() {
+ legacyIntentClassificationFactory = new LegacyClassificationIntentFactory();
+ }
+
+ @Test
+ public void create_typeDictionary() {
+ AnnotatorModel.ClassificationResult classificationResult =
+ new AnnotatorModel.ClassificationResult(
+ TYPE_DICTIONARY,
+ 1.0f,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ 0L,
+ 0L,
+ 0d);
+
+ List<LabeledIntent> intents =
+ legacyIntentClassificationFactory.create(
+ ApplicationProvider.getApplicationContext(),
+ TEXT,
+ /* foreignText */ false,
+ null,
+ classificationResult);
+
+ assertThat(intents).hasSize(1);
+ LabeledIntent labeledIntent = intents.get(0);
+ Intent intent = labeledIntent.intent;
+ assertThat(intent.getAction()).isEqualTo(Intent.ACTION_DEFINE);
+ assertThat(intent.getStringExtra(Intent.EXTRA_TEXT)).isEqualTo(TEXT);
+ }
+
+ @Test
+ public void create_translateAndDictionary() {
+ AnnotatorModel.ClassificationResult classificationResult =
+ new AnnotatorModel.ClassificationResult(
+ TYPE_DICTIONARY,
+ 1.0f,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ 0L,
+ 0L,
+ 0d);
+
+ List<LabeledIntent> intents =
+ legacyIntentClassificationFactory.create(
+ ApplicationProvider.getApplicationContext(),
+ TEXT,
+ /* foreignText */ true,
+ null,
+ classificationResult);
+
+ assertThat(intents).hasSize(2);
+ assertThat(intents.get(0).intent.getAction()).isEqualTo(Intent.ACTION_DEFINE);
+ assertThat(intents.get(1).intent.getAction()).isEqualTo(Intent.ACTION_TRANSLATE);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java
new file mode 100644
index 0000000..42176bd
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java
@@ -0,0 +1,240 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.intent;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.same;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+
+import android.content.Context;
+import android.content.Intent;
+import android.view.textclassifier.TextClassifier;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.google.android.textclassifier.AnnotatorModel;
+import com.google.android.textclassifier.RemoteActionTemplate;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TemplateClassificationIntentFactoryTest {
+
+ private static final String TEXT = "text";
+ private static final String TITLE_WITHOUT_ENTITY = "Map";
+ private static final String DESCRIPTION = "Opens in Maps";
+ private static final String DESCRIPTION_WITH_APP_NAME = "Use %1$s to open Map";
+ private static final String ACTION = Intent.ACTION_VIEW;
+
+ @Mock private ClassificationIntentFactory fallback;
+ private TemplateClassificationIntentFactory templateClassificationIntentFactory;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ templateClassificationIntentFactory =
+ new TemplateClassificationIntentFactory(new TemplateIntentFactory(), fallback);
+ }
+
+ @Test
+ public void create_foreignText() {
+ AnnotatorModel.ClassificationResult classificationResult =
+ new AnnotatorModel.ClassificationResult(
+ TextClassifier.TYPE_ADDRESS,
+ 1.0f,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ createRemoteActionTemplates(),
+ 0L,
+ 0L,
+ 0d);
+
+ List<LabeledIntent> intents =
+ templateClassificationIntentFactory.create(
+ ApplicationProvider.getApplicationContext(),
+ TEXT,
+ /* foreignText */ true,
+ null,
+ classificationResult);
+
+ assertThat(intents).hasSize(2);
+ LabeledIntent labeledIntent = intents.get(0);
+ assertThat(labeledIntent.titleWithoutEntity).isEqualTo(TITLE_WITHOUT_ENTITY);
+ Intent intent = labeledIntent.intent;
+ assertThat(intent.getAction()).isEqualTo(ACTION);
+
+ labeledIntent = intents.get(1);
+ intent = labeledIntent.intent;
+ assertThat(intent.getAction()).isEqualTo(Intent.ACTION_TRANSLATE);
+ }
+
+ @Test
+ public void create_notForeignText() {
+ AnnotatorModel.ClassificationResult classificationResult =
+ new AnnotatorModel.ClassificationResult(
+ TextClassifier.TYPE_ADDRESS,
+ 1.0f,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ createRemoteActionTemplates(),
+ 0L,
+ 0L,
+ 0d);
+
+ List<LabeledIntent> intents =
+ templateClassificationIntentFactory.create(
+ ApplicationProvider.getApplicationContext(),
+ TEXT,
+ /* foreignText */ false,
+ null,
+ classificationResult);
+
+ assertThat(intents).hasSize(1);
+ LabeledIntent labeledIntent = intents.get(0);
+ assertThat(labeledIntent.titleWithoutEntity).isEqualTo(TITLE_WITHOUT_ENTITY);
+ Intent intent = labeledIntent.intent;
+ assertThat(intent.getAction()).isEqualTo(ACTION);
+ }
+
+ @Test
+ public void create_nullTemplate() {
+ AnnotatorModel.ClassificationResult classificationResult =
+ new AnnotatorModel.ClassificationResult(
+ TextClassifier.TYPE_ADDRESS,
+ 1.0f,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ 0L,
+ 0L,
+ 0d);
+
+ templateClassificationIntentFactory.create(
+ ApplicationProvider.getApplicationContext(),
+ TEXT,
+ /* foreignText */ false,
+ null,
+ classificationResult);
+
+ verify(fallback)
+ .create(
+ same(ApplicationProvider.getApplicationContext()),
+ eq(TEXT),
+ eq(false),
+ eq(null),
+ same(classificationResult));
+ }
+
+ @Test
+ public void create_emptyResult() {
+ AnnotatorModel.ClassificationResult classificationResult =
+ new AnnotatorModel.ClassificationResult(
+ TextClassifier.TYPE_ADDRESS,
+ 1.0f,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ new RemoteActionTemplate[0],
+ 0L,
+ 0L,
+ 0d);
+
+ templateClassificationIntentFactory.create(
+ ApplicationProvider.getApplicationContext(),
+ TEXT,
+ /* foreignText */ false,
+ null,
+ classificationResult);
+
+ verify(fallback, never())
+ .create(
+ any(Context.class),
+ eq(TEXT),
+ eq(false),
+ eq(null),
+ any(AnnotatorModel.ClassificationResult.class));
+ }
+
+ private static RemoteActionTemplate[] createRemoteActionTemplates() {
+ return new RemoteActionTemplate[] {
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ null,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ ACTION,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null)
+ };
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateIntentFactoryTest.java b/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateIntentFactoryTest.java
new file mode 100644
index 0000000..ee45f18
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/intent/TemplateIntentFactoryTest.java
@@ -0,0 +1,265 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.intent;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Intent;
+import android.net.Uri;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.google.android.textclassifier.NamedVariant;
+import com.google.android.textclassifier.RemoteActionTemplate;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TemplateIntentFactoryTest {
+
+ private static final String TITLE_WITHOUT_ENTITY = "Map";
+ private static final String TITLE_WITH_ENTITY = "Map NW14D1";
+ private static final String DESCRIPTION = "Check the map";
+ private static final String DESCRIPTION_WITH_APP_NAME = "Use %1$s to open map";
+ private static final String ACTION = Intent.ACTION_VIEW;
+ private static final String DATA = Uri.parse("http://www.android.com").toString();
+ private static final String TYPE = "text/html";
+ private static final Integer FLAG = Intent.FLAG_ACTIVITY_NEW_TASK;
+ private static final String[] CATEGORY =
+ new String[] {Intent.CATEGORY_DEFAULT, Intent.CATEGORY_APP_BROWSER};
+ private static final String PACKAGE_NAME = "pkg.name";
+ private static final String KEY_ONE = "key1";
+ private static final String VALUE_ONE = "value1";
+ private static final String KEY_TWO = "key2";
+ private static final int VALUE_TWO = 42;
+
+ private static final NamedVariant[] NAMED_VARIANTS =
+ new NamedVariant[] {
+ new NamedVariant(KEY_ONE, VALUE_ONE), new NamedVariant(KEY_TWO, VALUE_TWO)
+ };
+ private static final Integer REQUEST_CODE = 10;
+
+ private TemplateIntentFactory templateIntentFactory;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ templateIntentFactory = new TemplateIntentFactory();
+ }
+
+ @Test
+ public void create_full() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ ACTION,
+ DATA,
+ TYPE,
+ FLAG,
+ CATEGORY,
+ /* packageName */ null,
+ NAMED_VARIANTS,
+ REQUEST_CODE);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).hasSize(1);
+ LabeledIntent labeledIntent = intents.get(0);
+ assertThat(labeledIntent.titleWithoutEntity).isEqualTo(TITLE_WITHOUT_ENTITY);
+ assertThat(labeledIntent.titleWithEntity).isEqualTo(TITLE_WITH_ENTITY);
+ assertThat(labeledIntent.description).isEqualTo(DESCRIPTION);
+ assertThat(labeledIntent.descriptionWithAppName).isEqualTo(DESCRIPTION_WITH_APP_NAME);
+ assertThat(labeledIntent.requestCode).isEqualTo(REQUEST_CODE);
+ Intent intent = labeledIntent.intent;
+ assertThat(intent.getAction()).isEqualTo(ACTION);
+ assertThat(intent.getData().toString()).isEqualTo(DATA);
+ assertThat(intent.getType()).isEqualTo(TYPE);
+ assertThat(intent.getFlags()).isEqualTo(FLAG);
+ assertThat(intent.getCategories()).containsExactly((Object[]) CATEGORY);
+ assertThat(intent.getPackage()).isNull();
+ assertThat(intent.getStringExtra(KEY_ONE)).isEqualTo(VALUE_ONE);
+ assertThat(intent.getIntExtra(KEY_TWO, 0)).isEqualTo(VALUE_TWO);
+ }
+
+ @Test
+ public void normalizesScheme() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ ACTION,
+ "HTTp://www.android.com",
+ TYPE,
+ FLAG,
+ CATEGORY,
+ /* packageName */ null,
+ NAMED_VARIANTS,
+ REQUEST_CODE);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ String data = intents.get(0).intent.getData().toString();
+ assertThat(data).isEqualTo("http://www.android.com");
+ }
+
+ @Test
+ public void create_minimal() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ null,
+ DESCRIPTION,
+ null,
+ ACTION,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).hasSize(1);
+ LabeledIntent labeledIntent = intents.get(0);
+ assertThat(labeledIntent.titleWithoutEntity).isEqualTo(TITLE_WITHOUT_ENTITY);
+ assertThat(labeledIntent.titleWithEntity).isNull();
+ assertThat(labeledIntent.description).isEqualTo(DESCRIPTION);
+ assertThat(labeledIntent.requestCode).isEqualTo(LabeledIntent.DEFAULT_REQUEST_CODE);
+ Intent intent = labeledIntent.intent;
+ assertThat(intent.getAction()).isEqualTo(ACTION);
+ assertThat(intent.getData()).isNull();
+ assertThat(intent.getType()).isNull();
+ assertThat(intent.getFlags()).isEqualTo(0);
+ assertThat(intent.getCategories()).isNull();
+ assertThat(intent.getPackage()).isNull();
+ }
+
+ @Test
+ public void invalidTemplate_nullTemplate() {
+ RemoteActionTemplate remoteActionTemplate = null;
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).isEmpty();
+ }
+
+ @Test
+ public void invalidTemplate_nonEmptyPackageName() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ ACTION,
+ DATA,
+ TYPE,
+ FLAG,
+ CATEGORY,
+ PACKAGE_NAME,
+ NAMED_VARIANTS,
+ REQUEST_CODE);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).isEmpty();
+ }
+
+ @Test
+ public void invalidTemplate_emptyTitle() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ null,
+ null,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ ACTION,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).isEmpty();
+ }
+
+ @Test
+ public void invalidTemplate_emptyDescription() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ null,
+ null,
+ ACTION,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).isEmpty();
+ }
+
+ @Test
+ public void invalidTemplate_emptyIntentAction() {
+ RemoteActionTemplate remoteActionTemplate =
+ new RemoteActionTemplate(
+ TITLE_WITHOUT_ENTITY,
+ TITLE_WITH_ENTITY,
+ DESCRIPTION,
+ DESCRIPTION_WITH_APP_NAME,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null,
+ null);
+
+ List<LabeledIntent> intents =
+ templateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
+
+ assertThat(intents).isEmpty();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/logging/ResultIdUtilsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/logging/ResultIdUtilsTest.java
new file mode 100644
index 0000000..571a97b
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/logging/ResultIdUtilsTest.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.logging;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import java.util.Collections;
+import java.util.Locale;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class ResultIdUtilsTest {
+ private static final int MODEL_VERSION = 703;
+ private static final int HASH = 12345;
+
+ @Test
+ public void createId_customHash() {
+ String resultId =
+ ResultIdUtils.createId(MODEL_VERSION, Collections.singletonList(Locale.ENGLISH), HASH);
+
+ assertThat(resultId).isEqualTo("androidtc|en_v703|12345");
+ }
+
+ @Test
+ public void createId_selection() {
+ String resultId =
+ ResultIdUtils.createId(
+ ApplicationProvider.getApplicationContext(),
+ "text",
+ 1,
+ 2,
+ MODEL_VERSION,
+ Collections.singletonList(Locale.ENGLISH));
+
+ assertThat(resultId).matches("androidtc\\|en_v703\\|-?\\d+");
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/logging/SelectionEventConverterTest.java b/java/tests/instrumentation/src/com/android/textclassifier/logging/SelectionEventConverterTest.java
new file mode 100644
index 0000000..1c4a356
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/logging/SelectionEventConverterTest.java
@@ -0,0 +1,191 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.logging;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.view.textclassifier.SelectionEvent;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationContext;
+import android.view.textclassifier.TextClassificationManager;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextClassifierEvent;
+import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import java.util.ArrayDeque;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.Locale;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class SelectionEventConverterTest {
+ private static final String PKG_NAME = "com.pkg";
+ private static final String WIDGET_TYPE = TextClassifier.WIDGET_TYPE_EDITTEXT;
+ private static final int START = 2;
+ private static final int SMART_START = 1;
+ private static final int SMART_END = 3;
+ private TestTextClassifier testTextClassifier;
+ private TextClassifier session;
+
+ @Before
+ public void setup() {
+ TextClassificationManager textClassificationManager =
+ ApplicationProvider.getApplicationContext()
+ .getSystemService(TextClassificationManager.class);
+ testTextClassifier = new TestTextClassifier();
+ textClassificationManager.setTextClassifier(testTextClassifier);
+ session = textClassificationManager.createTextClassificationSession(createEventContext());
+ }
+
+ @Test
+ public void convert_started() {
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_MANUAL, START));
+
+ SelectionEvent interceptedEvent = testTextClassifier.popLastSelectionEvent();
+ TextClassifierEvent textClassifierEvent =
+ SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
+
+ assertEventContext(textClassifierEvent.getEventContext());
+ assertThat(textClassifierEvent.getEventIndex()).isEqualTo(0);
+ assertThat(textClassifierEvent.getEventType())
+ .isEqualTo(TextClassifierEvent.TYPE_SELECTION_STARTED);
+ }
+
+ @Test
+ public void convert_smartSelection() {
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_MANUAL, START));
+ String resultId =
+ ResultIdUtils.createId(702, Collections.singletonList(Locale.ENGLISH), /*hash=*/ 12345);
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionActionEvent(
+ SMART_START,
+ SMART_END,
+ SelectionEvent.ACTION_SMART_SHARE,
+ new TextClassification.Builder()
+ .setEntityType(TextClassifier.TYPE_ADDRESS, 1.0f)
+ .setId(resultId)
+ .build()));
+
+ SelectionEvent interceptedEvent = testTextClassifier.popLastSelectionEvent();
+ TextClassifierEvent.TextSelectionEvent textSelectionEvent =
+ (TextClassifierEvent.TextSelectionEvent)
+ SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
+
+ assertEventContext(textSelectionEvent.getEventContext());
+ assertThat(textSelectionEvent.getRelativeWordStartIndex()).isEqualTo(-1);
+ assertThat(textSelectionEvent.getRelativeWordEndIndex()).isEqualTo(1);
+ assertThat(textSelectionEvent.getEventType()).isEqualTo(TextClassifierEvent.TYPE_SMART_ACTION);
+ assertThat(textSelectionEvent.getEventIndex()).isEqualTo(1);
+ assertThat(textSelectionEvent.getEntityTypes())
+ .asList()
+ .containsExactly(TextClassifier.TYPE_ADDRESS);
+ assertThat(textSelectionEvent.getResultId()).isEqualTo(resultId);
+ }
+
+ @Test
+ public void convert_smartShare() {
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_MANUAL, START));
+ String resultId =
+ ResultIdUtils.createId(702, Collections.singletonList(Locale.ENGLISH), /*hash=*/ 12345);
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionModifiedEvent(
+ SMART_START,
+ SMART_END,
+ new TextSelection.Builder(SMART_START, SMART_END)
+ .setEntityType(TextClassifier.TYPE_ADDRESS, 1.0f)
+ .setId(resultId)
+ .build()));
+
+ SelectionEvent interceptedEvent = testTextClassifier.popLastSelectionEvent();
+ TextClassifierEvent.TextSelectionEvent textSelectionEvent =
+ (TextClassifierEvent.TextSelectionEvent)
+ SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
+
+ assertEventContext(textSelectionEvent.getEventContext());
+ assertThat(textSelectionEvent.getRelativeSuggestedWordStartIndex()).isEqualTo(-1);
+ assertThat(textSelectionEvent.getRelativeSuggestedWordEndIndex()).isEqualTo(1);
+ assertThat(textSelectionEvent.getEventType())
+ .isEqualTo(TextClassifierEvent.TYPE_SMART_SELECTION_MULTI);
+ assertThat(textSelectionEvent.getEventIndex()).isEqualTo(1);
+ assertThat(textSelectionEvent.getEntityTypes())
+ .asList()
+ .containsExactly(TextClassifier.TYPE_ADDRESS);
+ assertThat(textSelectionEvent.getResultId()).isEqualTo(resultId);
+ }
+
+ @Test
+ public void convert_smartLinkify() {
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_LINK, START));
+ String resultId =
+ ResultIdUtils.createId(702, Collections.singletonList(Locale.ENGLISH), /*hash=*/ 12345);
+ session.onSelectionEvent(
+ SelectionEvent.createSelectionModifiedEvent(
+ SMART_START,
+ SMART_END,
+ new TextSelection.Builder(SMART_START, SMART_END)
+ .setEntityType(TextClassifier.TYPE_ADDRESS, 1.0f)
+ .setId(resultId)
+ .build()));
+
+ SelectionEvent interceptedEvent = testTextClassifier.popLastSelectionEvent();
+ TextClassifierEvent.TextLinkifyEvent textLinkifyEvent =
+ (TextClassifierEvent.TextLinkifyEvent)
+ SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
+
+ assertEventContext(textLinkifyEvent.getEventContext());
+ assertThat(textLinkifyEvent.getEventType())
+ .isEqualTo(TextClassifierEvent.TYPE_SMART_SELECTION_MULTI);
+ assertThat(textLinkifyEvent.getEventIndex()).isEqualTo(1);
+ assertThat(textLinkifyEvent.getEntityTypes())
+ .asList()
+ .containsExactly(TextClassifier.TYPE_ADDRESS);
+ assertThat(textLinkifyEvent.getResultId()).isEqualTo(resultId);
+ }
+
+ private static TextClassificationContext createEventContext() {
+ return new TextClassificationContext.Builder(PKG_NAME, TextClassifier.WIDGET_TYPE_EDITTEXT)
+ .build();
+ }
+
+ private static void assertEventContext(TextClassificationContext eventContext) {
+ assertThat(eventContext.getPackageName()).isEqualTo(PKG_NAME);
+ assertThat(eventContext.getWidgetType()).isEqualTo(WIDGET_TYPE);
+ }
+
+ private static class TestTextClassifier implements TextClassifier {
+ private final Deque<SelectionEvent> selectionEvents = new ArrayDeque<>();
+
+ @Override
+ public void onSelectionEvent(SelectionEvent event) {
+ selectionEvents.push(event);
+ }
+
+ SelectionEvent popLastSelectionEvent() {
+ return selectionEvents.pop();
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java b/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
new file mode 100644
index 0000000..6d01a64
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.subjects;
+
+import static com.google.common.truth.Truth.assertAbout;
+
+import com.android.textclassifier.Entity;
+import com.google.common.truth.FailureMetadata;
+import com.google.common.truth.MathUtil;
+import com.google.common.truth.Subject;
+import javax.annotation.Nullable;
+
+/** Test helper for checking {@link com.android.textclassifier.Entity} results. */
+public final class EntitySubject extends Subject<EntitySubject, Entity> {
+
+ private static final float TOLERANCE = 0.0001f;
+
+ private final Entity entity;
+
+ public static EntitySubject assertThat(@Nullable Entity entity) {
+ return assertAbout(EntitySubject::new).that(entity);
+ }
+
+ private EntitySubject(FailureMetadata failureMetadata, @Nullable Entity entity) {
+ super(failureMetadata, entity);
+ this.entity = entity;
+ }
+
+ public void isMatchWithinTolerance(@Nullable Entity entity) {
+ if (!entity.getEntityType().equals(this.entity.getEntityType())) {
+ failWithActual("expected to have type", entity.getEntityType());
+ }
+ if (!MathUtil.equalWithinTolerance(entity.getScore(), this.entity.getScore(), TOLERANCE)) {
+ failWithActual("expected to have confidence score", entity.getScore());
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzerTest.java
new file mode 100644
index 0000000..91d5e3e
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzerTest.java
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.ulp;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import android.content.Context;
+import androidx.room.Room;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.TextClassificationConstants;
+import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
+import com.android.textclassifier.ulp.database.LanguageSignalInfo;
+import java.util.Arrays;
+import java.util.Locale;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+/** Testing {@link BasicLanguageProficiencyAnalyzer} in an in-memory database. */
+public class BasicLanguageProficiencyAnalyzerTest {
+
+ private static final String PRIMARY_SYSTEM_LANGUAGE = Locale.CHINESE.toLanguageTag();
+ private static final String SECONDARY_SYSTEM_LANGUAGE = Locale.ENGLISH.toLanguageTag();
+ private static final String NON_SYSTEM_LANGUAGE = Locale.JAPANESE.toLanguageTag();
+
+ private LanguageProfileDatabase mDatabase;
+ private BasicLanguageProficiencyAnalyzer mProficiencyAnalyzer;
+ @Mock private SystemLanguagesProvider mSystemLanguagesProvider;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+
+ Context context = ApplicationProvider.getApplicationContext();
+ TextClassificationConstants textClassificationConstants =
+ mock(TextClassificationConstants.class);
+ mDatabase = Room.inMemoryDatabaseBuilder(context, LanguageProfileDatabase.class).build();
+ mProficiencyAnalyzer =
+ new BasicLanguageProficiencyAnalyzer(
+ textClassificationConstants, mDatabase, mSystemLanguagesProvider);
+ when(mSystemLanguagesProvider.getSystemLanguageTags())
+ .thenReturn(Arrays.asList(PRIMARY_SYSTEM_LANGUAGE, SECONDARY_SYSTEM_LANGUAGE));
+ when(textClassificationConstants.getLanguageProficiencyBootstrappingCount()).thenReturn(100);
+ }
+
+ @After
+ public void close() {
+ mDatabase.close();
+ }
+
+ @Test
+ public void canUnderstand_emptyDatabase() {
+ assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(0.5f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(NON_SYSTEM_LANGUAGE)).isEqualTo(0f);
+ }
+
+ @Test
+ public void canUnderstand_validRequest() {
+ mDatabase
+ .languageInfoDao()
+ .insertLanguageInfo(
+ new LanguageSignalInfo(
+ PRIMARY_SYSTEM_LANGUAGE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 100));
+ mDatabase
+ .languageInfoDao()
+ .insertLanguageInfo(
+ new LanguageSignalInfo(
+ SECONDARY_SYSTEM_LANGUAGE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 30));
+
+ assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(0.4f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(NON_SYSTEM_LANGUAGE)).isEqualTo(0f);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzerTest.java
new file mode 100644
index 0000000..d554cae
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzerTest.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.ulp;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import android.content.Context;
+import androidx.room.Room;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.TextClassificationConstants;
+import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
+import com.android.textclassifier.ulp.database.LanguageSignalInfo;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Locale;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+/** Testing {@link KmeansLanguageProficiencyAnalyzer} in an in-memory database. */
+public class KmeansLanguageProficiencyAnalyzerTest {
+
+ private static final String PRIMARY_SYSTEM_LANGUAGE = Locale.CHINESE.toLanguageTag();
+ private static final String SECONDARY_SYSTEM_LANGUAGE = Locale.ENGLISH.toLanguageTag();
+ private static final String NORMAL_LANGUAGE = Locale.JAPANESE.toLanguageTag();
+
+ private LanguageProfileDatabase mDatabase;
+ private KmeansLanguageProficiencyAnalyzer mProficiencyAnalyzer;
+ @Mock private SystemLanguagesProvider mSystemLanguagesProvider;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+
+ Context context = ApplicationProvider.getApplicationContext();
+ TextClassificationConstants textClassificationConstants =
+ mock(TextClassificationConstants.class);
+ mDatabase = Room.inMemoryDatabaseBuilder(context, LanguageProfileDatabase.class).build();
+ mProficiencyAnalyzer =
+ new KmeansLanguageProficiencyAnalyzer(
+ textClassificationConstants, mDatabase, mSystemLanguagesProvider);
+ when(mSystemLanguagesProvider.getSystemLanguageTags())
+ .thenReturn(Arrays.asList(PRIMARY_SYSTEM_LANGUAGE, SECONDARY_SYSTEM_LANGUAGE));
+ when(textClassificationConstants.getLanguageProficiencyBootstrappingCount()).thenReturn(100);
+ }
+
+ @After
+ public void close() {
+ mDatabase.close();
+ }
+
+ @Test
+ public void canUnderstand_emptyDatabase() {
+ assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(0.5f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(NORMAL_LANGUAGE)).isEqualTo(0f);
+ }
+
+ @Test
+ public void canUnderstand_oneLanguage() {
+ when(mSystemLanguagesProvider.getSystemLanguageTags())
+ .thenReturn(Collections.singletonList(PRIMARY_SYSTEM_LANGUAGE));
+ mDatabase
+ .languageInfoDao()
+ .insertLanguageInfo(
+ new LanguageSignalInfo(
+ PRIMARY_SYSTEM_LANGUAGE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 1));
+
+ assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(0f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(NORMAL_LANGUAGE)).isEqualTo(0f);
+ }
+
+ @Test
+ public void canUnderstand_twoLanguages() {
+ mDatabase
+ .languageInfoDao()
+ .insertLanguageInfo(
+ new LanguageSignalInfo(
+ PRIMARY_SYSTEM_LANGUAGE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 100));
+ mDatabase
+ .languageInfoDao()
+ .insertLanguageInfo(
+ new LanguageSignalInfo(
+ SECONDARY_SYSTEM_LANGUAGE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 50));
+
+ assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(0.5f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(NORMAL_LANGUAGE)).isEqualTo(0f);
+ }
+
+ @Test
+ public void canUnderstand_threeLanguages() {
+ mDatabase
+ .languageInfoDao()
+ .insertLanguageInfo(
+ new LanguageSignalInfo(
+ PRIMARY_SYSTEM_LANGUAGE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 100));
+ mDatabase
+ .languageInfoDao()
+ .insertLanguageInfo(
+ new LanguageSignalInfo(
+ SECONDARY_SYSTEM_LANGUAGE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 75));
+ mDatabase
+ .languageInfoDao()
+ .insertLanguageInfo(
+ new LanguageSignalInfo(
+ NORMAL_LANGUAGE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 2));
+
+ assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(NORMAL_LANGUAGE)).isEqualTo(0f);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluatorTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluatorTest.java
new file mode 100644
index 0000000..74ec27c
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluatorTest.java
@@ -0,0 +1,160 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.ulp;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.when;
+
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.google.common.collect.ImmutableSet;
+import java.util.Arrays;
+import java.util.Set;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+import org.mockito.stubbing.Answer;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class LanguageProficiencyEvaluatorTest {
+ private static final float EPSILON = 0.01f;
+ private LanguageProficiencyEvaluator mLanguageProficiencyEvaluator;
+
+ @Mock private SystemLanguagesProvider mSystemLanguagesProvider;
+
+ private static final String SYSTEM_LANGUAGE_EN = "en";
+ private static final String SYSTEM_LANGUAGE_ZH = "zh";
+ private static final String NORMAL_LANGUAGE_JP = "jp";
+ private static final String NORMAL_LANGUAGE_FR = "fr";
+ private static final String NORMAL_LANGUAGE_PL = "pl";
+ private static final Set<String> EVALUATION_LANGUAGES =
+ ImmutableSet.of(
+ SYSTEM_LANGUAGE_EN,
+ SYSTEM_LANGUAGE_ZH,
+ NORMAL_LANGUAGE_JP,
+ NORMAL_LANGUAGE_FR,
+ NORMAL_LANGUAGE_PL);
+
+ @Mock private LanguageProficiencyAnalyzer mLanguageProficiencyAnalyzer;
+
+ @Before
+ public void setUp() {
+ MockitoAnnotations.initMocks(this);
+ when(mSystemLanguagesProvider.getSystemLanguageTags())
+ .thenReturn(Arrays.asList(SYSTEM_LANGUAGE_EN, SYSTEM_LANGUAGE_ZH));
+ mLanguageProficiencyEvaluator = new LanguageProficiencyEvaluator(mSystemLanguagesProvider);
+ }
+
+ @Test
+ public void evaluate_allCorrect() {
+ when(mLanguageProficiencyAnalyzer.canUnderstand(Mockito.anyString()))
+ .thenAnswer(
+ (Answer<Float>)
+ invocation -> {
+ String languageTag = invocation.getArgument(0);
+ if (languageTag.equals(SYSTEM_LANGUAGE_EN)
+ || languageTag.equals(SYSTEM_LANGUAGE_ZH)) {
+ return 1f;
+ }
+ return 0f;
+ });
+
+ LanguageProficiencyEvaluator.EvaluationResult evaluationResult =
+ mLanguageProficiencyEvaluator.evaluate(mLanguageProficiencyAnalyzer, EVALUATION_LANGUAGES);
+
+ assertThat(evaluationResult.truePositive).isEqualTo(2);
+ assertThat(evaluationResult.trueNegative).isEqualTo(3);
+ assertThat(evaluationResult.falsePositive).isEqualTo(0);
+ assertThat(evaluationResult.falseNegative).isEqualTo(0);
+ assertThat(evaluationResult.computePrecisionOfPositiveClass()).isWithin(EPSILON).of(1f);
+ assertThat(evaluationResult.computePrecisionOfNegativeClass()).isWithin(EPSILON).of(1f);
+ assertThat(evaluationResult.computeRecallOfPositiveClass()).isWithin(EPSILON).of(1f);
+ assertThat(evaluationResult.computeRecallOfNegativeClass()).isWithin(EPSILON).of(1f);
+ assertThat(evaluationResult.computeF1ScoreOfPositiveClass()).isWithin(EPSILON).of(1f);
+ assertThat(evaluationResult.computeF1ScoreOfNegativeClass()).isWithin(EPSILON).of(1f);
+ }
+
+ @Test
+ public void evaluate_allWrong() {
+ when(mLanguageProficiencyAnalyzer.canUnderstand(Mockito.anyString()))
+ .thenAnswer(
+ (Answer<Float>)
+ invocation -> {
+ String languageTag = invocation.getArgument(0);
+ if (languageTag.equals(SYSTEM_LANGUAGE_EN)
+ || languageTag.equals(SYSTEM_LANGUAGE_ZH)) {
+ return 0f;
+ }
+ return 1f;
+ });
+
+ LanguageProficiencyEvaluator.EvaluationResult evaluationResult =
+ mLanguageProficiencyEvaluator.evaluate(mLanguageProficiencyAnalyzer, EVALUATION_LANGUAGES);
+
+ assertThat(evaluationResult.truePositive).isEqualTo(0);
+ assertThat(evaluationResult.trueNegative).isEqualTo(0);
+ assertThat(evaluationResult.falsePositive).isEqualTo(3);
+ assertThat(evaluationResult.falseNegative).isEqualTo(2);
+ assertThat(evaluationResult.computePrecisionOfPositiveClass()).isWithin(EPSILON).of(0f);
+ assertThat(evaluationResult.computePrecisionOfNegativeClass()).isWithin(EPSILON).of(0f);
+ assertThat(evaluationResult.computeRecallOfPositiveClass()).isWithin(EPSILON).of(0f);
+ assertThat(evaluationResult.computeRecallOfNegativeClass()).isWithin(EPSILON).of(0f);
+ assertThat(evaluationResult.computeF1ScoreOfPositiveClass()).isWithin(EPSILON).of(0f);
+ assertThat(evaluationResult.computeF1ScoreOfNegativeClass()).isWithin(EPSILON).of(0f);
+ }
+
+ @Test
+ public void evaluate_mixed() {
+ when(mLanguageProficiencyAnalyzer.canUnderstand(Mockito.anyString()))
+ .thenAnswer(
+ (Answer<Float>)
+ invocation -> {
+ String languageTag = invocation.getArgument(0);
+ switch (languageTag) {
+ case SYSTEM_LANGUAGE_EN:
+ return 1f;
+ case SYSTEM_LANGUAGE_ZH:
+ return 0f;
+ case NORMAL_LANGUAGE_FR:
+ return 0f;
+ case NORMAL_LANGUAGE_JP:
+ return 0f;
+ case NORMAL_LANGUAGE_PL:
+ return 1f;
+ }
+ throw new IllegalArgumentException("unexpected language: " + languageTag);
+ });
+
+ LanguageProficiencyEvaluator.EvaluationResult evaluationResult =
+ mLanguageProficiencyEvaluator.evaluate(mLanguageProficiencyAnalyzer, EVALUATION_LANGUAGES);
+
+ assertThat(evaluationResult.truePositive).isEqualTo(1);
+ assertThat(evaluationResult.trueNegative).isEqualTo(2);
+ assertThat(evaluationResult.falsePositive).isEqualTo(1);
+ assertThat(evaluationResult.falseNegative).isEqualTo(1);
+ assertThat(evaluationResult.computePrecisionOfPositiveClass()).isWithin(EPSILON).of(0.5f);
+ assertThat(evaluationResult.computePrecisionOfNegativeClass()).isWithin(EPSILON).of(0.66f);
+ assertThat(evaluationResult.computeRecallOfPositiveClass()).isWithin(EPSILON).of(0.5f);
+ assertThat(evaluationResult.computeRecallOfNegativeClass()).isWithin(EPSILON).of(0.66f);
+ assertThat(evaluationResult.computeF1ScoreOfPositiveClass()).isWithin(EPSILON).of(0.5f);
+ assertThat(evaluationResult.computeF1ScoreOfNegativeClass()).isWithin(EPSILON).of(0.66f);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ulp/LanguageProfileAnalyzerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ulp/LanguageProfileAnalyzerTest.java
new file mode 100644
index 0000000..46153db
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ulp/LanguageProfileAnalyzerTest.java
@@ -0,0 +1,141 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.ulp;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import android.content.Context;
+import androidx.room.Room;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.Entity;
+import com.android.textclassifier.TextClassificationConstants;
+import com.android.textclassifier.subjects.EntitySubject;
+import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
+import com.android.textclassifier.ulp.database.LanguageSignalInfo;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+/** Testing {@link LanguageProfileAnalyzer} in an inMemoryDatabase. */
+public class LanguageProfileAnalyzerTest {
+
+ private static final String SYSTEM_LANGUAGE_CODE = "en";
+ private static final String LOCATION_LANGUAGE_CODE = "jp";
+ private static final String NORMAL_LANGUAGE_CODE = "pl";
+
+ private LanguageProfileDatabase mDatabase;
+ private LanguageProfileAnalyzer mLanguageProfileAnalyzer;
+ @Mock private LocationSignalProvider mLocationSignalProvider;
+ @Mock private SystemLanguagesProvider mSystemLanguagesProvider;
+ @Mock private LanguageProficiencyAnalyzer mLanguageProficiencyAnalyzer;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+
+ Context mContext = ApplicationProvider.getApplicationContext();
+ mDatabase = Room.inMemoryDatabaseBuilder(mContext, LanguageProfileDatabase.class).build();
+ when(mLocationSignalProvider.detectLanguageTag()).thenReturn(LOCATION_LANGUAGE_CODE);
+ when(mSystemLanguagesProvider.getSystemLanguageTags())
+ .thenReturn(Collections.singletonList(SYSTEM_LANGUAGE_CODE));
+ when(mLanguageProficiencyAnalyzer.canUnderstand(anyString())).thenReturn(1.0f);
+ TextClassificationConstants customTextClassificationConstants =
+ mock(TextClassificationConstants.class);
+ when(customTextClassificationConstants.getFrequentLanguagesBootstrappingCount())
+ .thenReturn(100);
+ mLanguageProfileAnalyzer =
+ new LanguageProfileAnalyzer(
+ mContext,
+ customTextClassificationConstants,
+ mDatabase,
+ mLanguageProficiencyAnalyzer,
+ mLocationSignalProvider,
+ mSystemLanguagesProvider);
+ }
+
+ @After
+ public void close() {
+ mDatabase.close();
+ }
+
+ @Test
+ public void getFrequentLanguages_emptyDatabase() {
+ List<Entity> frequentLanguages =
+ mLanguageProfileAnalyzer.getFrequentLanguages(LanguageSignalInfo.CLASSIFY_TEXT);
+
+ assertThat(frequentLanguages).hasSize(2);
+ EntitySubject.assertThat(frequentLanguages.get(0))
+ .isMatchWithinTolerance(new Entity(SYSTEM_LANGUAGE_CODE, 1.0f));
+ EntitySubject.assertThat(frequentLanguages.get(1))
+ .isMatchWithinTolerance(new Entity(LOCATION_LANGUAGE_CODE, 1.0f));
+ }
+
+ @Test
+ public void getFrequentLanguages_mixedSignal() {
+ insertSignal(NORMAL_LANGUAGE_CODE, LanguageSignalInfo.CLASSIFY_TEXT, 50);
+ insertSignal(SYSTEM_LANGUAGE_CODE, LanguageSignalInfo.CLASSIFY_TEXT, 100);
+ // Unrelated signals.
+ insertSignal(NORMAL_LANGUAGE_CODE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 100);
+ insertSignal(SYSTEM_LANGUAGE_CODE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 100);
+ insertSignal(LOCATION_LANGUAGE_CODE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 100);
+
+ List<Entity> frequentLanguages =
+ mLanguageProfileAnalyzer.getFrequentLanguages(LanguageSignalInfo.CLASSIFY_TEXT);
+
+ assertThat(frequentLanguages).hasSize(3);
+ EntitySubject.assertThat(frequentLanguages.get(0))
+ .isMatchWithinTolerance(new Entity(SYSTEM_LANGUAGE_CODE, 1.0f));
+ EntitySubject.assertThat(frequentLanguages.get(1))
+ .isMatchWithinTolerance(new Entity(LOCATION_LANGUAGE_CODE, 0.5f));
+ EntitySubject.assertThat(frequentLanguages.get(2))
+ .isMatchWithinTolerance(new Entity(NORMAL_LANGUAGE_CODE, 0.25f));
+ }
+
+ @Test
+ public void getFrequentLanguages_bothSystemLanguageAndLocationLanguage() {
+ when(mLocationSignalProvider.detectLanguageTag()).thenReturn("en");
+ when(mSystemLanguagesProvider.getSystemLanguageTags()).thenReturn(Arrays.asList("en", "jp"));
+
+ List<Entity> frequentLanguages =
+ mLanguageProfileAnalyzer.getFrequentLanguages(LanguageSignalInfo.CLASSIFY_TEXT);
+
+ assertThat(frequentLanguages).hasSize(2);
+ EntitySubject.assertThat(frequentLanguages.get(0))
+ .isMatchWithinTolerance(new Entity("en", 1.0f));
+ EntitySubject.assertThat(frequentLanguages.get(1))
+ .isMatchWithinTolerance(new Entity("jp", 0.5f));
+ }
+
+ private void insertSignal(String languageTag, @LanguageSignalInfo.Source int source, int count) {
+ mDatabase
+ .languageInfoDao()
+ .insertLanguageInfo(new LanguageSignalInfo(languageTag, source, count));
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ulp/LanguageProfileUpdaterTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ulp/LanguageProfileUpdaterTest.java
new file mode 100644
index 0000000..02958b6
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ulp/LanguageProfileUpdaterTest.java
@@ -0,0 +1,213 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.ulp;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.Person;
+import android.content.Context;
+import android.os.Bundle;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.TextClassification;
+import androidx.room.Room;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
+import com.android.textclassifier.ulp.database.LanguageSignalInfo;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.time.ZoneId;
+import java.time.ZonedDateTime;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executors;
+import java.util.function.Function;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+/** Testing {@link LanguageProfileUpdater} in an in-memory database. */
+public class LanguageProfileUpdaterTest {
+
+ private static final String NOTIFICATION_KEY = "test_notification";
+ private static final String LOCALE_TAG_US = Locale.US.toLanguageTag();
+ private static final String LOCALE_TAG_CHINA = Locale.CHINA.toLanguageTag();
+ private static final String TEXT_ONE = "hello world";
+ private static final String TEXT_TWO = "你好!";
+ private static final Function<CharSequence, List<String>> LANGUAGE_DETECTOR_US =
+ charSequence -> ImmutableList.of(LOCALE_TAG_US);
+ private static final Function<CharSequence, List<String>> LANGUAGE_DETECTOR_CHINA =
+ charSequence -> ImmutableList.of(LOCALE_TAG_CHINA);
+ private static final Person PERSON = new Person.Builder().build();
+ private static final ZonedDateTime TIME_ONE =
+ ZonedDateTime.of(2019, 7, 21, 12, 12, 12, 12, ZoneId.systemDefault());
+ private static final ZonedDateTime TIME_TWO =
+ ZonedDateTime.of(2019, 7, 21, 12, 20, 20, 12, ZoneId.systemDefault());
+ private static final ConversationActions.Message MSG_ONE =
+ new ConversationActions.Message.Builder(PERSON)
+ .setReferenceTime(TIME_ONE)
+ .setText(TEXT_ONE)
+ .setExtras(new Bundle())
+ .build();
+ private static final ConversationActions.Message MSG_TWO =
+ new ConversationActions.Message.Builder(PERSON)
+ .setReferenceTime(TIME_TWO)
+ .setText("where are you?")
+ .setExtras(new Bundle())
+ .build();
+ private static final ConversationActions.Message MSG_THREE =
+ new ConversationActions.Message.Builder(PERSON)
+ .setReferenceTime(TIME_TWO)
+ .setText(TEXT_TWO)
+ .setExtras(new Bundle())
+ .build();
+ private static final ConversationActions.Request CONVERSATION_ACTION_REQUEST_ONE =
+ new ConversationActions.Request.Builder(Arrays.asList(MSG_ONE)).build();
+ private static final ConversationActions.Request CONVERSATION_ACTION_REQUEST_TWO =
+ new ConversationActions.Request.Builder(Arrays.asList(MSG_TWO)).build();
+ private static final TextClassification.Request TEXT_CLASSIFICATION_REQUEST_ONE =
+ new TextClassification.Request.Builder(TEXT_ONE, 0, 2).build();
+ private static final LanguageSignalInfo US_INFO_ONE_FOR_CONVERSATION_ACTION_ONE =
+ new LanguageSignalInfo(LOCALE_TAG_US, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 1);
+ private static final LanguageSignalInfo US_INFO_ONE_FOR_CONVERSATION_ACTION_TWO =
+ new LanguageSignalInfo(LOCALE_TAG_US, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 2);
+ private static final LanguageSignalInfo US_INFO_ONE_FOR_CLASSIFY_TEXT =
+ new LanguageSignalInfo(LOCALE_TAG_US, LanguageSignalInfo.CLASSIFY_TEXT, 1);
+
+ private LanguageProfileUpdater mLanguageProfileUpdater;
+ private LanguageProfileDatabase mDatabase;
+
+ @Before
+ public void setup() {
+ Context mContext = ApplicationProvider.getApplicationContext();
+ ListeningExecutorService mExecutorService =
+ MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor());
+ mDatabase = Room.inMemoryDatabaseBuilder(mContext, LanguageProfileDatabase.class).build();
+ mLanguageProfileUpdater = new LanguageProfileUpdater(mExecutorService, mDatabase);
+ }
+
+ @After
+ public void close() {
+ mDatabase.close();
+ }
+
+ @Test
+ public void updateFromConversationActionsAsync_oneMessage()
+ throws ExecutionException, InterruptedException {
+ mLanguageProfileUpdater
+ .updateFromConversationActionsAsync(CONVERSATION_ACTION_REQUEST_ONE, LANGUAGE_DETECTOR_US)
+ .get();
+ List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
+
+ assertThat(infos).hasSize(1);
+ LanguageSignalInfo info = infos.get(0);
+ assertThat(info).isEqualTo(US_INFO_ONE_FOR_CONVERSATION_ACTION_ONE);
+ }
+
+ /** Notification keys for these two messages are DEFAULT_NOTIFICATION_KEY */
+ @Test
+ public void updateFromConversationActionsAsync_twoMessagesInSameNotificationWithSameLanguage()
+ throws ExecutionException, InterruptedException {
+ mLanguageProfileUpdater
+ .updateFromConversationActionsAsync(CONVERSATION_ACTION_REQUEST_ONE, LANGUAGE_DETECTOR_US)
+ .get();
+ mLanguageProfileUpdater
+ .updateFromConversationActionsAsync(CONVERSATION_ACTION_REQUEST_TWO, LANGUAGE_DETECTOR_US)
+ .get();
+ List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
+
+ assertThat(infos).hasSize(1);
+ LanguageSignalInfo info = infos.get(0);
+ assertThat(info).isEqualTo(US_INFO_ONE_FOR_CONVERSATION_ACTION_TWO);
+ }
+
+ @Test
+ public void updateFromConversationActionsAsync_twoMessagesInDifferentNotifications()
+ throws ExecutionException, InterruptedException {
+ mLanguageProfileUpdater
+ .updateFromConversationActionsAsync(CONVERSATION_ACTION_REQUEST_ONE, LANGUAGE_DETECTOR_US)
+ .get();
+ Bundle extra = new Bundle();
+ extra.putString(LanguageProfileUpdater.NOTIFICATION_KEY, NOTIFICATION_KEY);
+ ConversationActions.Request newRequest =
+ new ConversationActions.Request.Builder(Arrays.asList(MSG_TWO)).setExtras(extra).build();
+ mLanguageProfileUpdater
+ .updateFromConversationActionsAsync(newRequest, LANGUAGE_DETECTOR_US)
+ .get();
+ List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
+
+ assertThat(infos).hasSize(1);
+ LanguageSignalInfo info = infos.get(0);
+ assertThat(info).isEqualTo(US_INFO_ONE_FOR_CONVERSATION_ACTION_TWO);
+ }
+
+ @Test
+ public void updateFromConversationActionsAsync_twoMessagesInDifferentLanguage()
+ throws ExecutionException, InterruptedException {
+ mLanguageProfileUpdater
+ .updateFromConversationActionsAsync(CONVERSATION_ACTION_REQUEST_ONE, LANGUAGE_DETECTOR_US)
+ .get();
+ ConversationActions.Request newRequest =
+ new ConversationActions.Request.Builder(Arrays.asList(MSG_THREE)).build();
+ mLanguageProfileUpdater
+ .updateFromConversationActionsAsync(newRequest, LANGUAGE_DETECTOR_CHINA)
+ .get();
+ List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
+
+ assertThat(infos).hasSize(2);
+ LanguageSignalInfo infoOne = infos.get(0);
+ LanguageSignalInfo infoTwo = infos.get(1);
+ assertThat(infoOne).isEqualTo(US_INFO_ONE_FOR_CONVERSATION_ACTION_ONE);
+ assertThat(infoTwo)
+ .isEqualTo(
+ new LanguageSignalInfo(
+ LOCALE_TAG_CHINA, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 1));
+ }
+
+ @Test
+ public void updateFromClassifyTextAsync_classifyText()
+ throws ExecutionException, InterruptedException {
+ mLanguageProfileUpdater.updateFromClassifyTextAsync(ImmutableList.of(LOCALE_TAG_US)).get();
+ List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
+
+ assertThat(infos).hasSize(1);
+ LanguageSignalInfo info = infos.get(0);
+ assertThat(info).isEqualTo(US_INFO_ONE_FOR_CLASSIFY_TEXT);
+ }
+
+ @Test
+ public void updateFromClassifyTextAsync_classifyTextTwice()
+ throws ExecutionException, InterruptedException {
+ mLanguageProfileUpdater.updateFromClassifyTextAsync(ImmutableList.of(LOCALE_TAG_US)).get();
+ mLanguageProfileUpdater.updateFromClassifyTextAsync(ImmutableList.of(LOCALE_TAG_CHINA)).get();
+
+ List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
+ assertThat(infos).hasSize(2);
+ LanguageSignalInfo infoOne = infos.get(0);
+ LanguageSignalInfo infoTwo = infos.get(1);
+ assertThat(infoOne).isEqualTo(US_INFO_ONE_FOR_CLASSIFY_TEXT);
+ assertThat(infoTwo)
+ .isEqualTo(new LanguageSignalInfo(LOCALE_TAG_CHINA, LanguageSignalInfo.CLASSIFY_TEXT, 1));
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ulp/LocationSignalProviderTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ulp/LocationSignalProviderTest.java
new file mode 100644
index 0000000..6ceacab
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ulp/LocationSignalProviderTest.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.ulp;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.when;
+
+import android.location.Address;
+import android.location.Geocoder;
+import android.location.Location;
+import android.location.LocationManager;
+import android.telephony.TelephonyManager;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.Locale;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.Mockito;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class LocationSignalProviderTest {
+ @Mock private LocationManager mLocationManager;
+ @Mock private TelephonyManager mTelephonyManager;
+ @Mock private LocationSignalProvider mLocationSignalProvider;
+ @Mock private Geocoder mGeocoder;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ mLocationSignalProvider =
+ new LocationSignalProvider(mLocationManager, mTelephonyManager, mGeocoder);
+ }
+
+ @Test
+ public void detectLanguageTag_useTelephony() {
+ when(mTelephonyManager.getNetworkCountryIso()).thenReturn(Locale.UK.getCountry());
+
+ assertThat(mLocationSignalProvider.detectLanguageTag()).isEqualTo("en");
+ }
+
+ @Test
+ public void detectLanguageTag_useLocation() throws IOException {
+ when(mTelephonyManager.getNetworkCountryIso()).thenReturn(null);
+ Location location = new Location(LocationManager.PASSIVE_PROVIDER);
+ when(mLocationManager.getLastKnownLocation(LocationManager.PASSIVE_PROVIDER))
+ .thenReturn(location);
+ Address address = new Address(Locale.FRANCE);
+ address.setCountryCode(Locale.FRANCE.getCountry());
+ when(mGeocoder.getFromLocation(Mockito.anyDouble(), Mockito.anyDouble(), Mockito.anyInt()))
+ .thenReturn(Collections.singletonList(address));
+
+ assertThat(mLocationSignalProvider.detectLanguageTag()).isEqualTo("fr");
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzerTest.java
new file mode 100644
index 0000000..dda6600
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzerTest.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.ulp;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.when;
+
+import android.content.Context;
+import android.content.SharedPreferences;
+import android.view.textclassifier.TextClassifierEvent;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import java.util.Arrays;
+import java.util.Locale;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+/** Testing {@link ReinforcementLanguageProficiencyAnalyzer} using Mockito. */
+public class ReinforcementLanguageProficiencyAnalyzerTest {
+
+ private static final String PRIMARY_SYSTEM_LANGUAGE = Locale.CHINESE.toLanguageTag();
+ private static final String SECONDARY_SYSTEM_LANGUAGE = Locale.ENGLISH.toLanguageTag();
+ private static final String NON_SYSTEM_LANGUAGE = Locale.JAPANESE.toLanguageTag();
+ private ReinforcementLanguageProficiencyAnalyzer mProficiencyAnalyzer;
+ @Mock private SystemLanguagesProvider mSystemLanguagesProvider;
+ private SharedPreferences mSharedPreferences;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+ Context context = ApplicationProvider.getApplicationContext();
+ mSharedPreferences = context.getSharedPreferences("test-preferences", Context.MODE_PRIVATE);
+ when(mSystemLanguagesProvider.getSystemLanguageTags())
+ .thenReturn(Arrays.asList(PRIMARY_SYSTEM_LANGUAGE, SECONDARY_SYSTEM_LANGUAGE));
+ mProficiencyAnalyzer =
+ new ReinforcementLanguageProficiencyAnalyzer(mSystemLanguagesProvider, mSharedPreferences);
+ }
+
+ @After
+ public void teardown() {
+ mSharedPreferences.edit().clear().apply();
+ }
+
+ @Test
+ public void canUnderstand_defaultValue() {
+ assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1.0f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(1.0f);
+ assertThat(mProficiencyAnalyzer.canUnderstand(NON_SYSTEM_LANGUAGE)).isEqualTo(0f);
+ }
+
+ @Test
+ public void canUnderstand_enoughFeedback() {
+ sendEvent(TextClassifierEvent.TYPE_ACTIONS_SHOWN, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 50);
+ sendEvent(TextClassifierEvent.TYPE_SMART_ACTION, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 40);
+
+ assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(0.8f);
+ }
+
+ @Test
+ public void shouldShowTranslation_defaultValue() {
+ assertThat(mProficiencyAnalyzer.shouldShowTranslation(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(true);
+ assertThat(mProficiencyAnalyzer.shouldShowTranslation(SECONDARY_SYSTEM_LANGUAGE))
+ .isEqualTo(true);
+ assertThat(mProficiencyAnalyzer.shouldShowTranslation(NON_SYSTEM_LANGUAGE)).isEqualTo(true);
+ }
+
+ @Test
+ public void shouldShowTranslation_enoughFeedback_true() {
+ sendEvent(TextClassifierEvent.TYPE_ACTIONS_SHOWN, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 1000);
+ sendEvent(TextClassifierEvent.TYPE_SMART_ACTION, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 200);
+
+ assertThat(mProficiencyAnalyzer.shouldShowTranslation(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(true);
+ }
+
+ @Test
+ public void shouldShowTranslation_enoughFeedback_false() {
+ sendEvent(TextClassifierEvent.TYPE_ACTIONS_SHOWN, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 1000);
+ sendEvent(TextClassifierEvent.TYPE_SMART_ACTION, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 1000);
+
+ assertThat(mProficiencyAnalyzer.shouldShowTranslation(PRIMARY_SYSTEM_LANGUAGE))
+ .isEqualTo(false);
+ }
+
+ private void sendEvent(int type, String languageTag, int times) {
+ TextClassifierEvent.LanguageDetectionEvent event =
+ new TextClassifierEvent.LanguageDetectionEvent.Builder(type)
+ .setEntityTypes(languageTag)
+ .setActionIndices(0)
+ .build();
+ for (int i = 0; i < times; i++) {
+ mProficiencyAnalyzer.onTextClassifierEvent(event);
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ulp/SystemLanguagesProviderTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ulp/SystemLanguagesProviderTest.java
new file mode 100644
index 0000000..45e8608
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ulp/SystemLanguagesProviderTest.java
@@ -0,0 +1,60 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.ulp;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.res.Resources;
+import android.os.LocaleList;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import java.util.List;
+import java.util.Locale;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class SystemLanguagesProviderTest {
+ private SystemLanguagesProvider mSystemLanguagesProvider;
+
+ @Before
+ public void setup() {
+ mSystemLanguagesProvider = new SystemLanguagesProvider();
+ }
+
+ @Test
+ public void getSystemLanguageTags_singleLanguages() {
+ Resources.getSystem().getConfiguration().setLocales(new LocaleList(Locale.FRANCE));
+
+ List<String> systemLanguageTags = mSystemLanguagesProvider.getSystemLanguageTags();
+
+ assertThat(systemLanguageTags).containsExactly("fr");
+ }
+
+ @Test
+ public void getSystemLanguageTags_multipleLanguages() {
+ Resources.getSystem()
+ .getConfiguration()
+ .setLocales(new LocaleList(Locale.FRANCE, Locale.ENGLISH));
+
+ List<String> systemLanguageTags = mSystemLanguagesProvider.getSystemLanguageTags();
+
+ assertThat(systemLanguageTags).containsExactly("fr", "en");
+ }
+}
diff --git a/java/tests/unittests/AndroidManifest.xml b/java/tests/unittests/AndroidManifest.xml
deleted file mode 100644
index 7864815..0000000
--- a/java/tests/unittests/AndroidManifest.xml
+++ /dev/null
@@ -1,29 +0,0 @@
-<?xml version="1.0" encoding="utf-8"?>
-<!--
- ~ Copyright (C) 2019 The Android Open Source Project
- ~
- ~ Licensed under the Apache License, Version 2.0 (the "License");
- ~ you may not use this file except in compliance with the License.
- ~ You may obtain a copy of the License at
- ~
- ~ http://www.apache.org/licenses/LICENSE-2.0
- ~
- ~ Unless required by applicable law or agreed to in writing, software
- ~ distributed under the License is distributed on an "AS IS" BASIS,
- ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- ~ See the License for the specific language governing permissions and
- ~ limitations under the License
- -->
-
-<manifest xmlns:android="http://schemas.android.com/apk/res/android"
- package="com.android.textclassifier.tests">
- <uses-permission android:name="android.permission.WRITE_DEVICE_CONFIG" />
- <application android:debuggable="true">
- <uses-library android:name="android.test.runner"/>
- </application>
-
- <instrumentation android:name="androidx.test.runner.AndroidJUnitRunner"
- android:targetPackage="com.android.textclassifier.tests"
- android:label="Tests for TextClassifierService"/>
-
-</manifest>
\ No newline at end of file
diff --git a/java/tests/unittests/AndroidTest.xml b/java/tests/unittests/AndroidTest.xml
deleted file mode 100644
index d55e8c5..0000000
--- a/java/tests/unittests/AndroidTest.xml
+++ /dev/null
@@ -1,27 +0,0 @@
-<?xml version="1.0" encoding="utf-8"?>
-<!--
- ~ Copyright (C) 2019 The Android Open Source Project
- ~
- ~ Licensed under the Apache License, Version 2.0 (the "License");
- ~ you may not use this file except in compliance with the License.
- ~ You may obtain a copy of the License at
- ~
- ~ http://www.apache.org/licenses/LICENSE-2.0
- ~
- ~ Unless required by applicable law or agreed to in writing, software
- ~ distributed under the License is distributed on an "AS IS" BASIS,
- ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- ~ See the License for the specific language governing permissions and
- ~ limitations under the License
- -->
-<configuration description="Runs Tests for TextClassifierServiceTest.">
- <target_preparer class="com.android.tradefed.targetprep.TestAppInstallSetup">
- <option name="test-file-name" value="TextClassifierServiceTest.apk"/>
- </target_preparer>
- <!--<option name="test-suite-tag" value="apct"/>-->
- <option name="test-tag" value="TextClassifierServiceTest"/>
- <test class="com.android.tradefed.testtype.AndroidJUnitTest">
- <option name="package" value="com.android.textclassifier.tests"/>
- <option name="runner" value="androidx.test.runner.AndroidJUnitRunner"/>
- </test>
-</configuration>
\ No newline at end of file
diff --git a/java/tests/unittests/src/com/android/textclassifier/ActionsModelParamsSupplierTest.java b/java/tests/unittests/src/com/android/textclassifier/ActionsModelParamsSupplierTest.java
deleted file mode 100644
index a8ae8aa..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/ActionsModelParamsSupplierTest.java
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License
- */
-package com.android.textclassifier;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
-
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-import java.io.File;
-import java.util.Collections;
-import java.util.Locale;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public class ActionsModelParamsSupplierTest {
-
- @Test
- public void getSerializedPreconditions_validActionsModelParams() {
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- new File("/model/file"),
- 200 /* version */,
- Collections.singletonList(Locale.forLanguageTag("en")),
- "en",
- false);
- byte[] serializedPreconditions = new byte[] {0x12, 0x24, 0x36};
- ActionsModelParamsSupplier.ActionsModelParams params =
- new ActionsModelParamsSupplier.ActionsModelParams(
- 200 /* version */, "en", serializedPreconditions);
-
- byte[] actual = params.getSerializedPreconditions(modelFile);
-
- assertThat(actual).isEqualTo(serializedPreconditions);
- }
-
- @Test
- public void getSerializedPreconditions_invalidVersion() {
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- new File("/model/file"),
- 201 /* version */,
- Collections.singletonList(Locale.forLanguageTag("en")),
- "en",
- false);
- byte[] serializedPreconditions = new byte[] {0x12, 0x24, 0x36};
- ActionsModelParamsSupplier.ActionsModelParams params =
- new ActionsModelParamsSupplier.ActionsModelParams(
- 200 /* version */, "en", serializedPreconditions);
-
- byte[] actual = params.getSerializedPreconditions(modelFile);
-
- assertThat(actual).isNull();
- }
-
- @Test
- public void getSerializedPreconditions_invalidLocales() {
- final String LANGUAGE_TAG = "zh";
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- new File("/model/file"),
- 200 /* version */,
- Collections.singletonList(Locale.forLanguageTag(LANGUAGE_TAG)),
- LANGUAGE_TAG,
- false);
- byte[] serializedPreconditions = new byte[] {0x12, 0x24, 0x36};
- ActionsModelParamsSupplier.ActionsModelParams params =
- new ActionsModelParamsSupplier.ActionsModelParams(
- 200 /* version */, "en", serializedPreconditions);
-
- byte[] actual = params.getSerializedPreconditions(modelFile);
-
- assertThat(actual).isNull();
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java b/java/tests/unittests/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
deleted file mode 100644
index c25fce8..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/ActionsSuggestionsHelperTest.java
+++ /dev/null
@@ -1,314 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License
- */
-
-package com.android.textclassifier;
-
-import static android.view.textclassifier.ConversationActions.Message.PERSON_USER_OTHERS;
-import static android.view.textclassifier.ConversationActions.Message.PERSON_USER_SELF;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import android.app.PendingIntent;
-import android.app.Person;
-import android.app.RemoteAction;
-import android.content.ComponentName;
-import android.content.Intent;
-import android.graphics.drawable.Icon;
-import android.net.Uri;
-import android.os.Bundle;
-import android.view.textclassifier.ConversationAction;
-import android.view.textclassifier.ConversationActions;
-
-import androidx.test.InstrumentationRegistry;
-import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
-
-import com.android.textclassifier.intent.LabeledIntent;
-import com.android.textclassifier.intent.TemplateIntentFactory;
-
-import com.google.android.textclassifier.ActionsSuggestionsModel;
-import com.google.android.textclassifier.RemoteActionTemplate;
-
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-import java.time.Instant;
-import java.time.ZoneId;
-import java.time.ZonedDateTime;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.Locale;
-import java.util.function.Function;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public class ActionsSuggestionsHelperTest {
- private static final String LOCALE_TAG = Locale.US.toLanguageTag();
- private static final Function<CharSequence, List<String>> LANGUAGE_DETECTOR =
- charSequence -> Collections.singletonList(LOCALE_TAG);
-
- @Test
- public void testToNativeMessages_emptyInput() {
- ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
- ActionsSuggestionsHelper.toNativeMessages(
- Collections.emptyList(), LANGUAGE_DETECTOR);
-
- assertThat(conversationMessages).isEmpty();
- }
-
- @Test
- public void testToNativeMessages_noTextMessages() {
- ConversationActions.Message messageWithoutText =
- new ConversationActions.Message.Builder(PERSON_USER_OTHERS).build();
-
- ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
- ActionsSuggestionsHelper.toNativeMessages(
- Collections.singletonList(messageWithoutText), LANGUAGE_DETECTOR);
-
- assertThat(conversationMessages).isEmpty();
- }
-
- @Test
- public void testToNativeMessages_userIdEncoding() {
- Person userA = new Person.Builder().setName("userA").build();
- Person userB = new Person.Builder().setName("userB").build();
-
- ConversationActions.Message firstMessage =
- new ConversationActions.Message.Builder(userB).setText("first").build();
- ConversationActions.Message secondMessage =
- new ConversationActions.Message.Builder(userA).setText("second").build();
- ConversationActions.Message thirdMessage =
- new ConversationActions.Message.Builder(PERSON_USER_SELF).setText("third").build();
- ConversationActions.Message fourthMessage =
- new ConversationActions.Message.Builder(userA).setText("fourth").build();
-
- ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
- ActionsSuggestionsHelper.toNativeMessages(
- Arrays.asList(firstMessage, secondMessage, thirdMessage, fourthMessage),
- LANGUAGE_DETECTOR);
-
- assertThat(conversationMessages).hasLength(4);
- assertNativeMessage(conversationMessages[0], firstMessage.getText(), 2, 0);
- assertNativeMessage(conversationMessages[1], secondMessage.getText(), 1, 0);
- assertNativeMessage(conversationMessages[2], thirdMessage.getText(), 0, 0);
- assertNativeMessage(conversationMessages[3], fourthMessage.getText(), 1, 0);
- }
-
- @Test
- public void testToNativeMessages_referenceTime() {
- ConversationActions.Message firstMessage =
- new ConversationActions.Message.Builder(PERSON_USER_OTHERS)
- .setText("first")
- .setReferenceTime(createZonedDateTimeFromMsUtc(1000))
- .build();
- ConversationActions.Message secondMessage =
- new ConversationActions.Message.Builder(PERSON_USER_OTHERS)
- .setText("second")
- .build();
- ConversationActions.Message thirdMessage =
- new ConversationActions.Message.Builder(PERSON_USER_OTHERS)
- .setText("third")
- .setReferenceTime(createZonedDateTimeFromMsUtc(2000))
- .build();
-
- ActionsSuggestionsModel.ConversationMessage[] conversationMessages =
- ActionsSuggestionsHelper.toNativeMessages(
- Arrays.asList(firstMessage, secondMessage, thirdMessage),
- LANGUAGE_DETECTOR);
-
- assertThat(conversationMessages).hasLength(3);
- assertNativeMessage(conversationMessages[0], firstMessage.getText(), 1, 1000);
- assertNativeMessage(conversationMessages[1], secondMessage.getText(), 1, 0);
- assertNativeMessage(conversationMessages[2], thirdMessage.getText(), 1, 2000);
- }
-
- @Test
- public void testDeduplicateActions() {
- Bundle phoneExtras = new Bundle();
- Intent phoneIntent = new Intent();
- phoneIntent.setComponent(new ComponentName("phone", "intent"));
- ExtrasUtils.putActionIntent(phoneExtras, phoneIntent);
-
- Bundle anotherPhoneExtras = new Bundle();
- Intent anotherPhoneIntent = new Intent();
- anotherPhoneIntent.setComponent(new ComponentName("phone", "another.intent"));
- ExtrasUtils.putActionIntent(anotherPhoneExtras, anotherPhoneIntent);
-
- Bundle urlExtras = new Bundle();
- Intent urlIntent = new Intent();
- urlIntent.setComponent(new ComponentName("url", "intent"));
- ExtrasUtils.putActionIntent(urlExtras, urlIntent);
-
- PendingIntent pendingIntent =
- PendingIntent.getActivity(
- InstrumentationRegistry.getTargetContext(), 0, phoneIntent, 0);
- Icon icon = Icon.createWithData(new byte[0], 0, 0);
- ConversationAction action =
- new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
- .setAction(new RemoteAction(icon, "label", "1", pendingIntent))
- .setExtras(phoneExtras)
- .build();
- ConversationAction actionWithSameLabel =
- new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
- .setAction(new RemoteAction(icon, "label", "2", pendingIntent))
- .setExtras(phoneExtras)
- .build();
- ConversationAction actionWithSamePackageButDifferentClass =
- new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
- .setAction(new RemoteAction(icon, "label", "3", pendingIntent))
- .setExtras(anotherPhoneExtras)
- .build();
- ConversationAction actionWithDifferentLabel =
- new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
- .setAction(new RemoteAction(icon, "another_label", "4", pendingIntent))
- .setExtras(phoneExtras)
- .build();
- ConversationAction actionWithDifferentPackage =
- new ConversationAction.Builder(ConversationAction.TYPE_OPEN_URL)
- .setAction(new RemoteAction(icon, "label", "5", pendingIntent))
- .setExtras(urlExtras)
- .build();
- ConversationAction actionWithoutRemoteAction =
- new ConversationAction.Builder(ConversationAction.TYPE_CREATE_REMINDER).build();
-
- List<ConversationAction> conversationActions =
- ActionsSuggestionsHelper.removeActionsWithDuplicates(
- Arrays.asList(
- action,
- actionWithSameLabel,
- actionWithSamePackageButDifferentClass,
- actionWithDifferentLabel,
- actionWithDifferentPackage,
- actionWithoutRemoteAction));
-
- assertThat(conversationActions).hasSize(3);
- assertThat(conversationActions.get(0).getAction().getContentDescription()).isEqualTo("4");
- assertThat(conversationActions.get(1).getAction().getContentDescription()).isEqualTo("5");
- assertThat(conversationActions.get(2).getAction()).isNull();
- }
-
- @Test
- public void testDeduplicateActions_nullComponent() {
- Bundle phoneExtras = new Bundle();
- Intent phoneIntent = new Intent(Intent.ACTION_DIAL);
- ExtrasUtils.putActionIntent(phoneExtras, phoneIntent);
- PendingIntent pendingIntent =
- PendingIntent.getActivity(
- InstrumentationRegistry.getTargetContext(), 0, phoneIntent, 0);
- Icon icon = Icon.createWithData(new byte[0], 0, 0);
- ConversationAction action =
- new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
- .setAction(new RemoteAction(icon, "label", "1", pendingIntent))
- .setExtras(phoneExtras)
- .build();
- ConversationAction actionWithSameLabel =
- new ConversationAction.Builder(ConversationAction.TYPE_CALL_PHONE)
- .setAction(new RemoteAction(icon, "label", "2", pendingIntent))
- .setExtras(phoneExtras)
- .build();
-
- List<ConversationAction> conversationActions =
- ActionsSuggestionsHelper.removeActionsWithDuplicates(
- Arrays.asList(action, actionWithSameLabel));
-
- assertThat(conversationActions).isEmpty();
- }
-
- public void createLabeledIntentResult_null() {
- ActionsSuggestionsModel.ActionSuggestion nativeSuggestion =
- new ActionsSuggestionsModel.ActionSuggestion(
- "text", ConversationAction.TYPE_OPEN_URL, 1.0f, null, null, null);
-
- LabeledIntent.Result labeledIntentResult =
- ActionsSuggestionsHelper.createLabeledIntentResult(
- InstrumentationRegistry.getTargetContext(),
- new TemplateIntentFactory(),
- nativeSuggestion);
-
- assertThat(labeledIntentResult).isNull();
- }
-
- @Test
- public void createLabeledIntentResult_emptyList() {
- ActionsSuggestionsModel.ActionSuggestion nativeSuggestion =
- new ActionsSuggestionsModel.ActionSuggestion(
- "text",
- ConversationAction.TYPE_OPEN_URL,
- 1.0f,
- null,
- null,
- new RemoteActionTemplate[0]);
-
- LabeledIntent.Result labeledIntentResult =
- ActionsSuggestionsHelper.createLabeledIntentResult(
- InstrumentationRegistry.getTargetContext(),
- new TemplateIntentFactory(),
- nativeSuggestion);
-
- assertThat(labeledIntentResult).isNull();
- }
-
- @Test
- public void createLabeledIntentResult() {
- ActionsSuggestionsModel.ActionSuggestion nativeSuggestion =
- new ActionsSuggestionsModel.ActionSuggestion(
- "text",
- ConversationAction.TYPE_OPEN_URL,
- 1.0f,
- null,
- null,
- new RemoteActionTemplate[] {
- new RemoteActionTemplate(
- "title",
- null,
- "description",
- null,
- Intent.ACTION_VIEW,
- Uri.parse("http://www.android.com").toString(),
- null,
- 0,
- null,
- null,
- null,
- 0)
- });
-
- LabeledIntent.Result labeledIntentResult =
- ActionsSuggestionsHelper.createLabeledIntentResult(
- InstrumentationRegistry.getTargetContext(),
- new TemplateIntentFactory(),
- nativeSuggestion);
-
- assertThat(labeledIntentResult.remoteAction.getTitle()).isEqualTo("title");
- assertThat(labeledIntentResult.resolvedIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
- }
-
- private ZonedDateTime createZonedDateTimeFromMsUtc(long msUtc) {
- return ZonedDateTime.ofInstant(Instant.ofEpochMilli(msUtc), ZoneId.of("UTC"));
- }
-
- private static void assertNativeMessage(
- ActionsSuggestionsModel.ConversationMessage nativeMessage,
- CharSequence text,
- int userId,
- long referenceTimeInMsUtc) {
- assertThat(nativeMessage.getText()).isEqualTo(text.toString());
- assertThat(nativeMessage.getUserId()).isEqualTo(userId);
- assertThat(nativeMessage.getDetectedTextLanguageTags()).isEqualTo(LOCALE_TAG);
- assertThat(nativeMessage.getReferenceTimeMsUtc()).isEqualTo(referenceTimeInMsUtc);
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/FakeContextBuilder.java b/java/tests/unittests/src/com/android/textclassifier/FakeContextBuilder.java
deleted file mode 100644
index 451229a..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/FakeContextBuilder.java
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License
- */
-
-package com.android.textclassifier;
-
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyInt;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-import android.annotation.Nullable;
-import android.content.ComponentName;
-import android.content.Context;
-import android.content.ContextWrapper;
-import android.content.Intent;
-import android.content.pm.ActivityInfo;
-import android.content.pm.ApplicationInfo;
-import android.content.pm.PackageManager;
-import android.content.pm.ResolveInfo;
-
-import androidx.test.InstrumentationRegistry;
-
-import com.google.common.base.Preconditions;
-
-import org.mockito.stubbing.Answer;
-
-import java.util.HashMap;
-import java.util.Map;
-import java.util.UUID;
-
-/** A builder used to build a fake context for testing. */
-public final class FakeContextBuilder {
-
- /** A component name that can be used for tests. */
- public static final ComponentName DEFAULT_COMPONENT = new ComponentName("pkg", "cls");
-
- private final PackageManager mPackageManager;
- private final ContextWrapper mContext;
- private final Map<String, ComponentName> mComponents = new HashMap<>();
- private final Map<String, CharSequence> mAppLabels = new HashMap<>();
- private @Nullable ComponentName mAllIntentComponent;
-
- public FakeContextBuilder() {
- mPackageManager = mock(PackageManager.class);
- when(mPackageManager.resolveActivity(any(Intent.class), anyInt())).thenReturn(null);
- mContext =
- new ContextWrapper(InstrumentationRegistry.getTargetContext()) {
- @Override
- public PackageManager getPackageManager() {
- return mPackageManager;
- }
- };
- }
-
- /**
- * Sets the component name of an activity to handle the specified intent action.
- *
- * <p><strong>NOTE: </strong>By default, no component is set to handle any intent.
- */
- public FakeContextBuilder setIntentComponent(
- String intentAction, @Nullable ComponentName component) {
- Preconditions.checkNotNull(intentAction);
- mComponents.put(intentAction, component);
- return this;
- }
-
- /** Sets the app label res for a specified package. */
- public FakeContextBuilder setAppLabel(String packageName, @Nullable CharSequence appLabel) {
- Preconditions.checkNotNull(packageName);
- mAppLabels.put(packageName, appLabel);
- return this;
- }
-
- /**
- * Sets the component name of an activity to handle all intents.
- *
- * <p><strong>NOTE: </strong>By default, no component is set to handle any intent.
- */
- public FakeContextBuilder setAllIntentComponent(@Nullable ComponentName component) {
- mAllIntentComponent = component;
- return this;
- }
-
- /** Builds and returns a fake context. */
- public Context build() {
- when(mPackageManager.resolveActivity(any(Intent.class), anyInt()))
- .thenAnswer(
- (Answer<ResolveInfo>)
- invocation -> {
- final String action =
- ((Intent) invocation.getArgument(0)).getAction();
- final ComponentName component =
- mComponents.containsKey(action)
- ? mComponents.get(action)
- : mAllIntentComponent;
- return getResolveInfo(component);
- });
- when(mPackageManager.getApplicationLabel(any(ApplicationInfo.class)))
- .thenAnswer(
- (Answer<CharSequence>)
- invocation -> {
- ApplicationInfo applicationInfo = invocation.getArgument(0);
- return mAppLabels.get(applicationInfo.packageName);
- });
- return mContext;
- }
-
- /** Returns a component name with random package and class names. */
- public static ComponentName newComponent() {
- return new ComponentName(UUID.randomUUID().toString(), UUID.randomUUID().toString());
- }
-
- private static ResolveInfo getResolveInfo(ComponentName component) {
- final ResolveInfo info;
- if (component == null) {
- info = null;
- } else {
- // NOTE: If something breaks in TextClassifier because we expect more fields to be set
- // in here, just add them.
- info = new ResolveInfo();
- info.activityInfo = new ActivityInfo();
- info.activityInfo.packageName = component.getPackageName();
- info.activityInfo.name = component.getClassName();
- info.activityInfo.exported = true;
- info.activityInfo.applicationInfo = new ApplicationInfo();
- info.activityInfo.applicationInfo.packageName = component.getPackageName();
- info.activityInfo.applicationInfo.icon = 0;
- }
- return info;
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/ModelFileManagerTest.java b/java/tests/unittests/src/com/android/textclassifier/ModelFileManagerTest.java
deleted file mode 100644
index c2626ff..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/ModelFileManagerTest.java
+++ /dev/null
@@ -1,373 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License
- */
-
-package com.android.textclassifier;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import static org.mockito.Mockito.when;
-
-import android.os.LocaleList;
-
-import androidx.test.InstrumentationRegistry;
-import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
-
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.Locale;
-import java.util.function.Supplier;
-import java.util.stream.Collectors;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public class ModelFileManagerTest {
- private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
- @Mock private Supplier<List<ModelFileManager.ModelFile>> mModelFileSupplier;
- private ModelFileManager.ModelFileSupplierImpl mModelFileSupplierImpl;
- private ModelFileManager mModelFileManager;
- private File mRootTestDir;
- private File mFactoryModelDir;
- private File mUpdatedModelFile;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
- mModelFileManager = new ModelFileManager(mModelFileSupplier);
- mRootTestDir = InstrumentationRegistry.getTargetContext().getCacheDir();
- mFactoryModelDir = new File(mRootTestDir, "factory");
- mUpdatedModelFile = new File(mRootTestDir, "updated.model");
-
- mModelFileSupplierImpl =
- new ModelFileManager.ModelFileSupplierImpl(
- mFactoryModelDir,
- "test\\d.model",
- mUpdatedModelFile,
- fd -> 1,
- fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT);
-
- mRootTestDir.mkdirs();
- mFactoryModelDir.mkdirs();
-
- Locale.setDefault(DEFAULT_LOCALE);
- }
-
- @After
- public void removeTestDir() {
- recursiveDelete(mRootTestDir);
- }
-
- @Test
- public void get() {
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- new File("/path/a"), 1, Collections.emptyList(), "", true);
- when(mModelFileSupplier.get()).thenReturn(Collections.singletonList(modelFile));
-
- List<ModelFileManager.ModelFile> modelFiles = mModelFileManager.listModelFiles();
-
- assertThat(modelFiles).hasSize(1);
- assertThat(modelFiles.get(0)).isEqualTo(modelFile);
- }
-
- @Test
- public void findBestModel_versionCode() {
- ModelFileManager.ModelFile olderModelFile =
- new ModelFileManager.ModelFile(
- new File("/path/a"), 1, Collections.emptyList(), "", true);
-
- ModelFileManager.ModelFile newerModelFile =
- new ModelFileManager.ModelFile(
- new File("/path/b"), 2, Collections.emptyList(), "", true);
- when(mModelFileSupplier.get()).thenReturn(Arrays.asList(olderModelFile, newerModelFile));
-
- ModelFileManager.ModelFile bestModelFile =
- mModelFileManager.findBestModelFile(LocaleList.getEmptyLocaleList());
-
- assertThat(bestModelFile).isEqualTo(newerModelFile);
- }
-
- @Test
- public void findBestModel_languageDependentModelIsPreferred() {
- Locale locale = Locale.forLanguageTag("ja");
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(
- new File("/path/a"), 1, Collections.emptyList(), "", true);
-
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 1,
- Collections.singletonList(locale),
- locale.toLanguageTag(),
- false);
- when(mModelFileSupplier.get())
- .thenReturn(
- Arrays.asList(languageIndependentModelFile, languageDependentModelFile));
-
- ModelFileManager.ModelFile bestModelFile =
- mModelFileManager.findBestModelFile(
- LocaleList.forLanguageTags(locale.toLanguageTag()));
- assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
- }
-
- @Test
- public void findBestModel_noMatchedLanguageModel() {
- Locale locale = Locale.forLanguageTag("ja");
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(
- new File("/path/a"), 1, Collections.emptyList(), "", true);
-
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 1,
- Collections.singletonList(locale),
- locale.toLanguageTag(),
- false);
-
- when(mModelFileSupplier.get())
- .thenReturn(
- Arrays.asList(languageIndependentModelFile, languageDependentModelFile));
-
- ModelFileManager.ModelFile bestModelFile =
- mModelFileManager.findBestModelFile(LocaleList.forLanguageTags("zh-hk"));
- assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
- }
-
- @Test
- public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() {
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(
- new File("/path/a"), 1, Collections.emptyList(), "", true);
-
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 1,
- Collections.singletonList(DEFAULT_LOCALE),
- DEFAULT_LOCALE.toLanguageTag(),
- false);
-
- when(mModelFileSupplier.get())
- .thenReturn(
- Arrays.asList(languageIndependentModelFile, languageDependentModelFile));
-
- ModelFileManager.ModelFile bestModelFile =
- mModelFileManager.findBestModelFile(LocaleList.forLanguageTags("zh-hk"));
- assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
- }
-
- @Test
- public void findBestModel_languageIsMoreImportantThanVersion() {
- ModelFileManager.ModelFile matchButOlderModel =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("fr")),
- "fr",
- false);
-
- ModelFileManager.ModelFile mismatchButNewerModel =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 2,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- when(mModelFileSupplier.get())
- .thenReturn(Arrays.asList(matchButOlderModel, mismatchButNewerModel));
-
- ModelFileManager.ModelFile bestModelFile =
- mModelFileManager.findBestModelFile(LocaleList.forLanguageTags("fr"));
- assertThat(bestModelFile).isEqualTo(matchButOlderModel);
- }
-
- @Test
- public void findBestModel_languageIsMoreImportantThanVersion_bestModelComesFirst() {
- ModelFileManager.ModelFile matchLocaleModel =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- ModelFileManager.ModelFile languageIndependentModel =
- new ModelFileManager.ModelFile(
- new File("/path/a"), 2, Collections.emptyList(), "", true);
- when(mModelFileSupplier.get())
- .thenReturn(Arrays.asList(matchLocaleModel, languageIndependentModel));
-
- ModelFileManager.ModelFile bestModelFile =
- mModelFileManager.findBestModelFile(LocaleList.forLanguageTags("ja"));
-
- assertThat(bestModelFile).isEqualTo(matchLocaleModel);
- }
-
- @Test
- public void modelFileEquals() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA).isEqualTo(modelB);
- }
-
- @Test
- public void modelFile_different() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- new File("/path/b"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA).isNotEqualTo(modelB);
- }
-
- @Test
- public void modelFile_getPath() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA.getPath()).isEqualTo("/path/a");
- }
-
- @Test
- public void modelFile_getName() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA.getName()).isEqualTo("a");
- }
-
- @Test
- public void modelFile_isPreferredTo_languageDependentIsBetter() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- new File("/path/b"), 2, Collections.emptyList(), "", true);
-
- assertThat(modelA.isPreferredTo(modelB)).isTrue();
- }
-
- @Test
- public void modelFile_isPreferredTo_version() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- new File("/path/a"),
- 2,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- new File("/path/b"), 1, Collections.emptyList(), "", false);
-
- assertThat(modelA.isPreferredTo(modelB)).isTrue();
- }
-
- @Test
- public void testFileSupplierImpl_updatedFileOnly() throws IOException {
- mUpdatedModelFile.createNewFile();
- File model1 = new File(mFactoryModelDir, "test1.model");
- model1.createNewFile();
- File model2 = new File(mFactoryModelDir, "test2.model");
- model2.createNewFile();
- new File(mFactoryModelDir, "not_match_regex.model").createNewFile();
-
- List<ModelFileManager.ModelFile> modelFiles = mModelFileSupplierImpl.get();
- List<String> modelFilePaths =
- modelFiles.stream()
- .map(modelFile -> modelFile.getPath())
- .collect(Collectors.toList());
-
- assertThat(modelFiles).hasSize(3);
- assertThat(modelFilePaths)
- .containsExactly(
- mUpdatedModelFile.getAbsolutePath(),
- model1.getAbsolutePath(),
- model2.getAbsolutePath());
- }
-
- @Test
- public void testFileSupplierImpl_empty() {
- mFactoryModelDir.delete();
- List<ModelFileManager.ModelFile> modelFiles = mModelFileSupplierImpl.get();
-
- assertThat(modelFiles).hasSize(0);
- }
-
- private static void recursiveDelete(File f) {
- if (f.isDirectory()) {
- for (File innerFile : f.listFiles()) {
- recursiveDelete(innerFile);
- }
- }
- f.delete();
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/StringUtilsTest.java b/java/tests/unittests/src/com/android/textclassifier/StringUtilsTest.java
deleted file mode 100644
index c3fc717..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/StringUtilsTest.java
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License
- */
-
-package com.android.textclassifier;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import static org.testng.Assert.assertThrows;
-
-import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
-
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public class StringUtilsTest {
-
- @Test
- public void testGetSubString() {
- final String text = "Yakuza call themselves 任侠団体";
- int start;
- int end;
- int minimumLength;
-
- // End index at end of text.
- start = text.indexOf("任侠団体");
- end = text.length();
- minimumLength = 20;
- assertThat(StringUtils.getSubString(text, start, end, minimumLength))
- .isEqualTo("call themselves 任侠団体");
-
- // Start index at beginning of text.
- start = 0;
- end = "Yakuza".length();
- minimumLength = 15;
- assertThat(StringUtils.getSubString(text, start, end, minimumLength))
- .isEqualTo("Yakuza call themselves");
-
- // Text in the middle
- start = text.indexOf("all");
- end = start + 1;
- minimumLength = 10;
- assertThat(StringUtils.getSubString(text, start, end, minimumLength))
- .isEqualTo("Yakuza call themselves");
-
- // Selection >= minimumLength.
- start = text.indexOf("themselves");
- end = start + "themselves".length();
- minimumLength = end - start;
- assertThat(StringUtils.getSubString(text, start, end, minimumLength))
- .isEqualTo("themselves");
-
- // text.length < minimumLength.
- minimumLength = text.length() + 1;
- assertThat(StringUtils.getSubString(text, start, end, minimumLength)).isEqualTo(text);
- }
-
- @Test
- public void testGetSubString_invalidParams() {
- final String text = "The Yoruba regard Olodumare as the principal agent of creation";
- final int length = text.length();
- final int minimumLength = 10;
-
- // Null text
- assertThrows(() -> StringUtils.getSubString(null, 0, 1, minimumLength));
- // start > end
- assertThrows(() -> StringUtils.getSubString(text, 6, 5, minimumLength));
- // start < 0
- assertThrows(() -> StringUtils.getSubString(text, -1, 5, minimumLength));
- // end > text.length
- assertThrows(() -> StringUtils.getSubString(text, 6, length + 1, minimumLength));
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/TextClassificationConstantsTest.java b/java/tests/unittests/src/com/android/textclassifier/TextClassificationConstantsTest.java
deleted file mode 100644
index b7c9246..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/TextClassificationConstantsTest.java
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License
- */
-
-package com.android.textclassifier;
-
-import static com.google.common.truth.Truth.assertWithMessage;
-
-import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
-
-import org.junit.Assert;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public class TextClassificationConstantsTest {
-
- private static final float EPSILON = 0.0001f;
-
- @Test
- public void testLoadFromString_defaultValues() {
- final TextClassificationConstants constants = new TextClassificationConstants();
-
- assertWithMessage("suggest_selection_max_range_length")
- .that(constants.getSuggestSelectionMaxRangeLength())
- .isEqualTo(10 * 1000);
- assertWithMessage("classify_text_max_range_length")
- .that(constants.getClassifyTextMaxRangeLength())
- .isEqualTo(10 * 1000);
- assertWithMessage("generate_links_max_text_length")
- .that(constants.getGenerateLinksMaxTextLength())
- .isEqualTo(100 * 1000);
- // assertWithMessage("generate_links_log_sample_rate")
- // .that(constants.getGenerateLinksLogSampleRate()).isEqualTo(100);
- assertWithMessage("entity_list_default")
- .that(constants.getEntityListDefault())
- .containsExactly("address", "email", "url", "phone", "date", "datetime", "flight");
- assertWithMessage("entity_list_not_editable")
- .that(constants.getEntityListNotEditable())
- .containsExactly("address", "email", "url", "phone", "date", "datetime", "flight");
- assertWithMessage("entity_list_editable")
- .that(constants.getEntityListEditable())
- .containsExactly("address", "email", "url", "phone", "date", "datetime", "flight");
- assertWithMessage("in_app_conversation_action_types_default")
- .that(constants.getInAppConversationActionTypes())
- .containsExactly(
- "text_reply",
- "create_reminder",
- "call_phone",
- "open_url",
- "send_email",
- "send_sms",
- "track_flight",
- "view_calendar",
- "view_map",
- "add_contact",
- "copy");
- assertWithMessage("notification_conversation_action_types_default")
- .that(constants.getNotificationConversationActionTypes())
- .containsExactly(
- "text_reply",
- "create_reminder",
- "call_phone",
- "open_url",
- "send_email",
- "send_sms",
- "track_flight",
- "view_calendar",
- "view_map",
- "add_contact",
- "copy");
- assertWithMessage("lang_id_threshold_override")
- .that(constants.getLangIdThresholdOverride())
- .isWithin(EPSILON)
- .of(-1f);
- Assert.assertArrayEquals(
- "lang_id_context_settings",
- constants.getLangIdContextSettings(),
- new float[] {20, 1, 0.4f},
- EPSILON);
- assertWithMessage("detect_language_from_text_enabled")
- .that(constants.isDetectLanguagesFromTextEnabled())
- .isTrue();
- assertWithMessage("template_intent_factory_enabled")
- .that(constants.isTemplateIntentFactoryEnabled())
- .isTrue();
- assertWithMessage("translate_in_classification_enabled")
- .that(constants.isTranslateInClassificationEnabled())
- .isTrue();
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/unittests/src/com/android/textclassifier/TextClassifierImplTest.java
deleted file mode 100644
index 6faba34..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/TextClassifierImplTest.java
+++ /dev/null
@@ -1,681 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License
- */
-
-package com.android.textclassifier;
-
-import static org.hamcrest.CoreMatchers.not;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertThat;
-import static org.junit.Assert.assertTrue;
-
-import android.app.RemoteAction;
-import android.content.Context;
-import android.content.Intent;
-import android.net.Uri;
-import android.os.Bundle;
-import android.os.LocaleList;
-import android.text.Spannable;
-import android.text.SpannableString;
-import android.view.textclassifier.ConversationAction;
-import android.view.textclassifier.ConversationActions;
-import android.view.textclassifier.TextClassification;
-import android.view.textclassifier.TextClassifier;
-import android.view.textclassifier.TextLanguage;
-import android.view.textclassifier.TextLinks;
-import android.view.textclassifier.TextSelection;
-
-import androidx.test.InstrumentationRegistry;
-import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
-
-import com.google.common.truth.Truth;
-
-import org.hamcrest.BaseMatcher;
-import org.hamcrest.Description;
-import org.hamcrest.Matcher;
-import org.junit.Before;
-import org.junit.Ignore;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-/**
- * Testing {@link TextClassifierImplTest} APIs on local and system textclassifier.
- *
- * <p>Tests are skipped if such a textclassifier does not exist.
- */
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public class TextClassifierImplTest {
-
- // TODO: Implement TextClassifierService testing.
-
- private static final LocaleList LOCALES = LocaleList.forLanguageTags("en-US");
- private static final String NO_TYPE = null;
-
- private Context mContext;
- private TextClassifierImpl mClassifier;
-
- @Before
- public void setup() {
- mContext = InstrumentationRegistry.getTargetContext();
- mClassifier = new TextClassifierImpl(mContext, new TextClassificationConstants());
- }
-
- @Test
- public void testSuggestSelection() {
- String text = "Contact me at droid@android.com";
- String selected = "droid";
- String suggested = "droid@android.com";
- int startIndex = text.indexOf(selected);
- int endIndex = startIndex + selected.length();
- int smartStartIndex = text.indexOf(suggested);
- int smartEndIndex = smartStartIndex + suggested.length();
- TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
-
- TextSelection selection = mClassifier.suggestSelection(request);
- assertThat(
- selection,
- isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_EMAIL));
- }
-
- @Test
- public void testSuggestSelection_url() {
- String text = "Visit http://www.android.com for more information";
- String selected = "http";
- String suggested = "http://www.android.com";
- int startIndex = text.indexOf(selected);
- int endIndex = startIndex + selected.length();
- int smartStartIndex = text.indexOf(suggested);
- int smartEndIndex = smartStartIndex + suggested.length();
- TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
-
- TextSelection selection = mClassifier.suggestSelection(request);
- assertThat(
- selection,
- isTextSelection(smartStartIndex, smartEndIndex, TextClassifier.TYPE_URL));
- }
-
- @Test
- public void testSmartSelection_withEmoji() {
- String text = "\uD83D\uDE02 Hello.";
- String selected = "Hello";
- int startIndex = text.indexOf(selected);
- int endIndex = startIndex + selected.length();
- TextSelection.Request request =
- new TextSelection.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
-
- TextSelection selection = mClassifier.suggestSelection(request);
- assertThat(selection, isTextSelection(startIndex, endIndex, NO_TYPE));
- }
-
- @Test
- public void testClassifyText() {
- String text = "Contact me at droid@android.com";
- String classifiedText = "droid@android.com";
- int startIndex = text.indexOf(classifiedText);
- int endIndex = startIndex + classifiedText.length();
- TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
-
- TextClassification classification = mClassifier.classifyText(request);
- assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_EMAIL));
- }
-
- @Test
- public void testClassifyText_url() {
- String text = "Visit www.android.com for more information";
- String classifiedText = "www.android.com";
- int startIndex = text.indexOf(classifiedText);
- int endIndex = startIndex + classifiedText.length();
- TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
-
- TextClassification classification = mClassifier.classifyText(request);
- assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
- assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
- }
-
- @Test
- public void testClassifyText_address() {
- String text = "Brandschenkestrasse 110, Zürich, Switzerland";
- TextClassification.Request request =
- new TextClassification.Request.Builder(text, 0, text.length())
- .setDefaultLocales(LOCALES)
- .build();
-
- TextClassification classification = mClassifier.classifyText(request);
- assertThat(classification, isTextClassification(text, TextClassifier.TYPE_ADDRESS));
- }
-
- @Test
- public void testClassifyText_url_inCaps() {
- String text = "Visit HTTP://ANDROID.COM for more information";
- String classifiedText = "HTTP://ANDROID.COM";
- int startIndex = text.indexOf(classifiedText);
- int endIndex = startIndex + classifiedText.length();
- TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
-
- TextClassification classification = mClassifier.classifyText(request);
- assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_URL));
- assertThat(classification, containsIntentWithAction(Intent.ACTION_VIEW));
- }
-
- @Test
- public void testClassifyText_date() {
- String text = "Let's meet on January 9, 2018.";
- String classifiedText = "January 9, 2018";
- int startIndex = text.indexOf(classifiedText);
- int endIndex = startIndex + classifiedText.length();
- TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
-
- TextClassification classification = mClassifier.classifyText(request);
- assertThat(classification, isTextClassification(classifiedText, TextClassifier.TYPE_DATE));
- Bundle extras = classification.getExtras();
- List<Bundle> entities = ExtrasUtils.getEntities(extras);
- Truth.assertThat(entities).hasSize(1);
- Bundle entity = entities.get(0);
- Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_DATE);
- }
-
- @Test
- public void testClassifyText_datetime() {
- String text = "Let's meet 2018/01/01 10:30:20.";
- String classifiedText = "2018/01/01 10:30:20";
- int startIndex = text.indexOf(classifiedText);
- int endIndex = startIndex + classifiedText.length();
- TextClassification.Request request =
- new TextClassification.Request.Builder(text, startIndex, endIndex)
- .setDefaultLocales(LOCALES)
- .build();
-
- TextClassification classification = mClassifier.classifyText(request);
- assertThat(
- classification,
- isTextClassification(classifiedText, TextClassifier.TYPE_DATE_TIME));
- }
-
- @Test
- public void testClassifyText_foreignText() {
- LocaleList originalLocales = LocaleList.getDefault();
- LocaleList.setDefault(LocaleList.forLanguageTags("en"));
- String japaneseText = "これは日本語のテキストです";
-
- Context context =
- new FakeContextBuilder()
- .setIntentComponent(
- Intent.ACTION_TRANSLATE, FakeContextBuilder.DEFAULT_COMPONENT)
- .build();
- TextClassifierImpl textClassifier =
- new TextClassifierImpl(context, new TextClassificationConstants());
- TextClassification.Request request =
- new TextClassification.Request.Builder(japaneseText, 0, japaneseText.length())
- .setDefaultLocales(LOCALES)
- .build();
-
- TextClassification classification = textClassifier.classifyText(request);
- RemoteAction translateAction = classification.getActions().get(0);
- assertEquals(1, classification.getActions().size());
- assertEquals(
- context.getString(com.android.textclassifier.R.string.translate),
- translateAction.getTitle());
-
- assertEquals(translateAction, ExtrasUtils.findTranslateAction(classification));
- Intent intent = ExtrasUtils.getActionsIntents(classification).get(0);
- assertEquals(Intent.ACTION_TRANSLATE, intent.getAction());
- Bundle foreignLanguageInfo = ExtrasUtils.getForeignLanguageExtra(classification);
- assertEquals("ja", ExtrasUtils.getEntityType(foreignLanguageInfo));
- assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) >= 0);
- assertTrue(ExtrasUtils.getScore(foreignLanguageInfo) <= 1);
- assertTrue(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER));
- assertEquals("ja", ExtrasUtils.getTopLanguage(intent).getLanguage());
-
- LocaleList.setDefault(originalLocales);
- }
-
- @Test
- public void testGenerateLinks_phone() {
- String text = "The number is +12122537077. See you tonight!";
- TextLinks.Request request = new TextLinks.Request.Builder(text).build();
- assertThat(
- mClassifier.generateLinks(request),
- isTextLinksContaining(text, "+12122537077", TextClassifier.TYPE_PHONE));
- }
-
- @Test
- public void testGenerateLinks_exclude() {
- String text = "You want apple@banana.com. See you tonight!";
- List<String> hints = Collections.EMPTY_LIST;
- List<String> included = Collections.EMPTY_LIST;
- List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
- TextLinks.Request request =
- new TextLinks.Request.Builder(text)
- .setEntityConfig(
- TextClassifier.EntityConfig.create(hints, included, excluded))
- .setDefaultLocales(LOCALES)
- .build();
- assertThat(
- mClassifier.generateLinks(request),
- not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
- }
-
- @Test
- public void testGenerateLinks_explicit_address() {
- String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
- List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
- TextLinks.Request request =
- new TextLinks.Request.Builder(text)
- .setEntityConfig(
- TextClassifier.EntityConfig.createWithExplicitEntityList(explicit))
- .setDefaultLocales(LOCALES)
- .build();
- assertThat(
- mClassifier.generateLinks(request),
- isTextLinksContaining(
- text,
- "1600 Amphitheater Parkway, Mountain View, CA",
- TextClassifier.TYPE_ADDRESS));
- }
-
- @Test
- public void testGenerateLinks_exclude_override() {
- String text = "You want apple@banana.com. See you tonight!";
- List<String> hints = Collections.EMPTY_LIST;
- List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
- List<String> excluded = Arrays.asList(TextClassifier.TYPE_EMAIL);
- TextLinks.Request request =
- new TextLinks.Request.Builder(text)
- .setEntityConfig(
- TextClassifier.EntityConfig.create(hints, included, excluded))
- .setDefaultLocales(LOCALES)
- .build();
- assertThat(
- mClassifier.generateLinks(request),
- not(isTextLinksContaining(text, "apple@banana.com", TextClassifier.TYPE_EMAIL)));
- }
-
- @Test
- public void testGenerateLinks_maxLength() {
- char[] manySpaces = new char[mClassifier.getMaxGenerateLinksTextLength()];
- Arrays.fill(manySpaces, ' ');
- TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
- TextLinks links = mClassifier.generateLinks(request);
- assertTrue(links.getLinks().isEmpty());
- }
-
- @Test
- public void testApplyLinks_unsupportedCharacter() {
- Spannable url = new SpannableString("\u202Emoc.diordna.com");
- TextLinks.Request request = new TextLinks.Request.Builder(url).build();
- assertEquals(
- TextLinks.STATUS_UNSUPPORTED_CHARACTER,
- mClassifier.generateLinks(request).apply(url, 0, null));
- }
-
- @Test(expected = IllegalArgumentException.class)
- public void testGenerateLinks_tooLong() {
- char[] manySpaces = new char[mClassifier.getMaxGenerateLinksTextLength() + 1];
- Arrays.fill(manySpaces, ' ');
- TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
- mClassifier.generateLinks(request);
- }
-
- @Test
- public void testGenerateLinks_entityData() {
- String text = "The number is +12122537077.";
- Bundle extras = new Bundle();
- ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
- TextLinks.Request request = new TextLinks.Request.Builder(text).setExtras(extras).build();
-
- TextLinks textLinks = mClassifier.generateLinks(request);
-
- Truth.assertThat(textLinks.getLinks()).hasSize(1);
- TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
- List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
- Truth.assertThat(entities).hasSize(1);
- Bundle entity = entities.get(0);
- Truth.assertThat(ExtrasUtils.getEntityType(entity)).isEqualTo(TextClassifier.TYPE_PHONE);
- }
-
- @Test
- public void testGenerateLinks_entityData_disabled() {
- String text = "The number is +12122537077.";
- TextLinks.Request request = new TextLinks.Request.Builder(text).build();
-
- TextLinks textLinks = mClassifier.generateLinks(request);
-
- Truth.assertThat(textLinks.getLinks()).hasSize(1);
- TextLinks.TextLink textLink = textLinks.getLinks().iterator().next();
- List<Bundle> entities = ExtrasUtils.getEntities(textLink.getExtras());
- Truth.assertThat(entities).isNull();
- }
-
- @Test
- public void testDetectLanguage() {
- String text = "This is English text";
- TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
- TextLanguage textLanguage = mClassifier.detectLanguage(request);
- assertThat(textLanguage, isTextLanguage("en"));
- }
-
- @Test
- public void testDetectLanguage_japanese() {
- String text = "これは日本語のテキストです";
- TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
- TextLanguage textLanguage = mClassifier.detectLanguage(request);
- assertThat(textLanguage, isTextLanguage("ja"));
- }
-
- @Ignore // Doesn't work without a language-based model.
- @Test
- public void testSuggestConversationActions_textReplyOnly_maxOne() {
- ConversationActions.Message message =
- new ConversationActions.Message.Builder(
- ConversationActions.Message.PERSON_USER_OTHERS)
- .setText("Where are you?")
- .build();
- TextClassifier.EntityConfig typeConfig =
- new TextClassifier.EntityConfig.Builder()
- .includeTypesFromTextClassifier(false)
- .setIncludedTypes(
- Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
- .build();
- ConversationActions.Request request =
- new ConversationActions.Request.Builder(Collections.singletonList(message))
- .setMaxSuggestions(1)
- .setTypeConfig(typeConfig)
- .build();
-
- ConversationActions conversationActions = mClassifier.suggestConversationActions(request);
- Truth.assertThat(conversationActions.getConversationActions()).hasSize(1);
- ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
- Truth.assertThat(conversationAction.getType())
- .isEqualTo(ConversationAction.TYPE_TEXT_REPLY);
- Truth.assertThat(conversationAction.getTextReply()).isNotNull();
- }
-
- @Ignore // Doesn't work without a language-based model.
- @Test
- public void testSuggestConversationActions_textReplyOnly_noMax() {
- ConversationActions.Message message =
- new ConversationActions.Message.Builder(
- ConversationActions.Message.PERSON_USER_OTHERS)
- .setText("Where are you?")
- .build();
- TextClassifier.EntityConfig typeConfig =
- new TextClassifier.EntityConfig.Builder()
- .includeTypesFromTextClassifier(false)
- .setIncludedTypes(
- Collections.singletonList(ConversationAction.TYPE_TEXT_REPLY))
- .build();
- ConversationActions.Request request =
- new ConversationActions.Request.Builder(Collections.singletonList(message))
- .setTypeConfig(typeConfig)
- .build();
-
- ConversationActions conversationActions = mClassifier.suggestConversationActions(request);
- assertTrue(conversationActions.getConversationActions().size() > 1);
- for (ConversationAction conversationAction : conversationActions.getConversationActions()) {
- assertThat(
- conversationAction, isConversationAction(ConversationAction.TYPE_TEXT_REPLY));
- }
- }
-
- @Test
- public void testSuggestConversationActions_openUrl() {
- ConversationActions.Message message =
- new ConversationActions.Message.Builder(
- ConversationActions.Message.PERSON_USER_OTHERS)
- .setText("Check this out: https://www.android.com")
- .build();
- TextClassifier.EntityConfig typeConfig =
- new TextClassifier.EntityConfig.Builder()
- .includeTypesFromTextClassifier(false)
- .setIncludedTypes(
- Collections.singletonList(ConversationAction.TYPE_OPEN_URL))
- .build();
- ConversationActions.Request request =
- new ConversationActions.Request.Builder(Collections.singletonList(message))
- .setMaxSuggestions(1)
- .setTypeConfig(typeConfig)
- .build();
-
- ConversationActions conversationActions = mClassifier.suggestConversationActions(request);
- Truth.assertThat(conversationActions.getConversationActions()).hasSize(1);
- ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
- Truth.assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
- Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
- Truth.assertThat(actionIntent.getAction()).isEqualTo(Intent.ACTION_VIEW);
- Truth.assertThat(actionIntent.getData()).isEqualTo(Uri.parse("https://www.android.com"));
- }
-
- @Ignore // Doesn't work without a language-based model.
- @Test
- public void testSuggestConversationActions_copy() {
- ConversationActions.Message message =
- new ConversationActions.Message.Builder(
- ConversationActions.Message.PERSON_USER_OTHERS)
- .setText("Authentication code: 12345")
- .build();
- TextClassifier.EntityConfig typeConfig =
- new TextClassifier.EntityConfig.Builder()
- .includeTypesFromTextClassifier(false)
- .setIncludedTypes(Collections.singletonList(ConversationAction.TYPE_COPY))
- .build();
- ConversationActions.Request request =
- new ConversationActions.Request.Builder(Collections.singletonList(message))
- .setMaxSuggestions(1)
- .setTypeConfig(typeConfig)
- .build();
-
- ConversationActions conversationActions = mClassifier.suggestConversationActions(request);
- Truth.assertThat(conversationActions.getConversationActions()).hasSize(1);
- ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
- Truth.assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_COPY);
- Truth.assertThat(conversationAction.getTextReply()).isAnyOf(null, "");
- Truth.assertThat(conversationAction.getAction()).isNull();
- String code = ExtrasUtils.getCopyText(conversationAction.getExtras());
- Truth.assertThat(code).isEqualTo("12345");
- Truth.assertThat(ExtrasUtils.getSerializedEntityData(conversationAction.getExtras()))
- .isNotEmpty();
- }
-
- @Test
- public void testSuggetsConversationActions_deduplicate() {
- ConversationActions.Message message =
- new ConversationActions.Message.Builder(
- ConversationActions.Message.PERSON_USER_OTHERS)
- .setText("a@android.com b@android.com")
- .build();
- ConversationActions.Request request =
- new ConversationActions.Request.Builder(Collections.singletonList(message))
- .setMaxSuggestions(3)
- .build();
-
- ConversationActions conversationActions = mClassifier.suggestConversationActions(request);
-
- Truth.assertThat(conversationActions.getConversationActions()).isEmpty();
- }
-
- private static Matcher<TextSelection> isTextSelection(
- final int startIndex, final int endIndex, final String type) {
- return new BaseMatcher<TextSelection>() {
- @Override
- public boolean matches(Object o) {
- if (o instanceof TextSelection) {
- TextSelection selection = (TextSelection) o;
- return startIndex == selection.getSelectionStartIndex()
- && endIndex == selection.getSelectionEndIndex()
- && typeMatches(selection, type);
- }
- return false;
- }
-
- private boolean typeMatches(TextSelection selection, String type) {
- return type == null
- || (selection.getEntityCount() > 0
- && type.trim().equalsIgnoreCase(selection.getEntity(0)));
- }
-
- @Override
- public void describeTo(Description description) {
- description.appendValue(String.format("%d, %d, %s", startIndex, endIndex, type));
- }
- };
- }
-
- private static Matcher<TextLinks> isTextLinksContaining(
- final String text, final String substring, final String type) {
- return new BaseMatcher<TextLinks>() {
-
- @Override
- public void describeTo(Description description) {
- description
- .appendText("text=")
- .appendValue(text)
- .appendText(", substring=")
- .appendValue(substring)
- .appendText(", type=")
- .appendValue(type);
- }
-
- @Override
- public boolean matches(Object o) {
- if (o instanceof TextLinks) {
- for (TextLinks.TextLink link : ((TextLinks) o).getLinks()) {
- if (text.subSequence(link.getStart(), link.getEnd()).equals(substring)) {
- return type.equals(link.getEntity(0));
- }
- }
- }
- return false;
- }
- };
- }
-
- private static Matcher<TextClassification> isTextClassification(
- final String text, final String type) {
- return new BaseMatcher<TextClassification>() {
- @Override
- public boolean matches(Object o) {
- if (o instanceof TextClassification) {
- TextClassification result = (TextClassification) o;
- return text.equals(result.getText())
- && result.getEntityCount() > 0
- && type.equals(result.getEntity(0));
- }
- return false;
- }
-
- @Override
- public void describeTo(Description description) {
- description
- .appendText("text=")
- .appendValue(text)
- .appendText(", type=")
- .appendValue(type);
- }
- };
- }
-
- private static Matcher<TextClassification> containsIntentWithAction(final String action) {
- return new BaseMatcher<TextClassification>() {
- @Override
- public boolean matches(Object o) {
- if (o instanceof TextClassification) {
- TextClassification result = (TextClassification) o;
- return ExtrasUtils.findAction(result, action) != null;
- }
- return false;
- }
-
- @Override
- public void describeTo(Description description) {
- description.appendText("intent action=").appendValue(action);
- }
- };
- }
-
- private static Matcher<TextLanguage> isTextLanguage(final String languageTag) {
- return new BaseMatcher<TextLanguage>() {
- @Override
- public boolean matches(Object o) {
- if (o instanceof TextLanguage) {
- TextLanguage result = (TextLanguage) o;
- return result.getLocaleHypothesisCount() > 0
- && languageTag.equals(result.getLocale(0).toLanguageTag());
- }
- return false;
- }
-
- @Override
- public void describeTo(Description description) {
- description.appendText("locale=").appendValue(languageTag);
- }
- };
- }
-
- private static Matcher<ConversationAction> isConversationAction(String actionType) {
- return new BaseMatcher<ConversationAction>() {
- @Override
- public boolean matches(Object o) {
- if (!(o instanceof ConversationAction)) {
- return false;
- }
- ConversationAction conversationAction = (ConversationAction) o;
- if (!actionType.equals(conversationAction.getType())) {
- return false;
- }
- if (ConversationAction.TYPE_TEXT_REPLY.equals(actionType)) {
- if (conversationAction.getTextReply() == null) {
- return false;
- }
- }
- if (conversationAction.getConfidenceScore() < 0
- || conversationAction.getConfidenceScore() > 1) {
- return false;
- }
- return true;
- }
-
- @Override
- public void describeTo(Description description) {
- description.appendText("actionType=").appendValue(actionType);
- }
- };
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/intent/LabeledIntentTest.java b/java/tests/unittests/src/com/android/textclassifier/intent/LabeledIntentTest.java
deleted file mode 100644
index 98ece06..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/intent/LabeledIntentTest.java
+++ /dev/null
@@ -1,184 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.intent;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import static org.testng.Assert.assertThrows;
-
-import android.content.ComponentName;
-import android.content.Context;
-import android.content.Intent;
-import android.net.Uri;
-import android.os.Bundle;
-import android.view.textclassifier.TextClassifier;
-
-import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
-
-import com.android.textclassifier.FakeContextBuilder;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public final class LabeledIntentTest {
- private static final String TITLE_WITHOUT_ENTITY = "Map";
- private static final String TITLE_WITH_ENTITY = "Map NW14D1";
- private static final String DESCRIPTION = "Check the map";
- private static final String DESCRIPTION_WITH_APP_NAME = "Use %1$s to open map";
- private static final Intent INTENT =
- new Intent(Intent.ACTION_VIEW).setDataAndNormalize(Uri.parse("http://www.android.com"));
- private static final int REQUEST_CODE = 42;
- private static final Bundle TEXT_LANGUAGES_BUNDLE = Bundle.EMPTY;
- private static final String APP_LABEL = "fake";
-
- private Context mContext;
-
- @Before
- public void setup() {
- final ComponentName component = FakeContextBuilder.DEFAULT_COMPONENT;
- mContext =
- new FakeContextBuilder()
- .setIntentComponent(Intent.ACTION_VIEW, component)
- .setAppLabel(component.getPackageName(), APP_LABEL)
- .build();
- }
-
- @Test
- public void resolve_preferTitleWithEntity() {
- LabeledIntent labeledIntent =
- new LabeledIntent(
- TITLE_WITHOUT_ENTITY,
- TITLE_WITH_ENTITY,
- DESCRIPTION,
- null,
- INTENT,
- REQUEST_CODE);
-
- LabeledIntent.Result result =
- labeledIntent.resolve(mContext, /*titleChooser*/ null, TEXT_LANGUAGES_BUNDLE);
-
- assertThat(result).isNotNull();
- assertThat(result.remoteAction.getTitle()).isEqualTo(TITLE_WITH_ENTITY);
- assertThat(result.remoteAction.getContentDescription()).isEqualTo(DESCRIPTION);
- Intent intent = result.resolvedIntent;
- assertThat(intent.getAction()).isEqualTo(intent.getAction());
- assertThat(intent.getComponent()).isNotNull();
- assertThat(intent.hasExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER)).isTrue();
- }
-
- @Test
- public void resolve_useAvailableTitle() {
- LabeledIntent labeledIntent =
- new LabeledIntent(
- TITLE_WITHOUT_ENTITY, null, DESCRIPTION, null, INTENT, REQUEST_CODE);
-
- LabeledIntent.Result result =
- labeledIntent.resolve(mContext, /*titleChooser*/ null, TEXT_LANGUAGES_BUNDLE);
-
- assertThat(result).isNotNull();
- assertThat(result.remoteAction.getTitle()).isEqualTo(TITLE_WITHOUT_ENTITY);
- assertThat(result.remoteAction.getContentDescription()).isEqualTo(DESCRIPTION);
- Intent intent = result.resolvedIntent;
- assertThat(intent.getAction()).isEqualTo(intent.getAction());
- assertThat(intent.getComponent()).isNotNull();
- }
-
- @Test
- public void resolve_titleChooser() {
- LabeledIntent labeledIntent =
- new LabeledIntent(
- TITLE_WITHOUT_ENTITY, null, DESCRIPTION, null, INTENT, REQUEST_CODE);
-
- LabeledIntent.Result result =
- labeledIntent.resolve(
- mContext,
- (labeledIntent1, resolveInfo) -> "chooser",
- TEXT_LANGUAGES_BUNDLE);
-
- assertThat(result).isNotNull();
- assertThat(result.remoteAction.getTitle()).isEqualTo("chooser");
- assertThat(result.remoteAction.getContentDescription()).isEqualTo(DESCRIPTION);
- Intent intent = result.resolvedIntent;
- assertThat(intent.getAction()).isEqualTo(intent.getAction());
- assertThat(intent.getComponent()).isNotNull();
- }
-
- @Test
- public void resolve_titleChooserReturnsNull() {
- LabeledIntent labeledIntent =
- new LabeledIntent(
- TITLE_WITHOUT_ENTITY, null, DESCRIPTION, null, INTENT, REQUEST_CODE);
-
- LabeledIntent.Result result =
- labeledIntent.resolve(
- mContext, (labeledIntent1, resolveInfo) -> null, TEXT_LANGUAGES_BUNDLE);
-
- assertThat(result).isNotNull();
- assertThat(result.remoteAction.getTitle()).isEqualTo(TITLE_WITHOUT_ENTITY);
- assertThat(result.remoteAction.getContentDescription()).isEqualTo(DESCRIPTION);
- Intent intent = result.resolvedIntent;
- assertThat(intent.getAction()).isEqualTo(intent.getAction());
- assertThat(intent.getComponent()).isNotNull();
- }
-
- @Test
- public void resolve_missingTitle() {
- assertThrows(
- IllegalArgumentException.class,
- () -> new LabeledIntent(null, null, DESCRIPTION, null, INTENT, REQUEST_CODE));
- }
-
- @Test
- public void resolve_noIntentHandler() {
- // See setup(). mContext can only resolve Intent.ACTION_VIEW.
- Intent unresolvableIntent = new Intent(Intent.ACTION_TRANSLATE);
- LabeledIntent labeledIntent =
- new LabeledIntent(
- TITLE_WITHOUT_ENTITY,
- null,
- DESCRIPTION,
- null,
- unresolvableIntent,
- REQUEST_CODE);
-
- LabeledIntent.Result result = labeledIntent.resolve(mContext, null, null);
-
- assertThat(result).isNull();
- }
-
- @Test
- public void resolve_descriptionWithAppName() {
- LabeledIntent labeledIntent =
- new LabeledIntent(
- TITLE_WITHOUT_ENTITY,
- TITLE_WITH_ENTITY,
- DESCRIPTION,
- DESCRIPTION_WITH_APP_NAME,
- INTENT,
- REQUEST_CODE);
-
- LabeledIntent.Result result =
- labeledIntent.resolve(mContext, /*titleChooser*/ null, TEXT_LANGUAGES_BUNDLE);
-
- assertThat(result).isNotNull();
- assertThat(result.remoteAction.getContentDescription()).isEqualTo("Use fake to open map");
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java b/java/tests/unittests/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java
deleted file mode 100644
index 6ae9e5a..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/intent/LegacyIntentClassificationFactoryTest.java
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License
- */
-
-package com.android.textclassifier.intent;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import android.content.Intent;
-import android.view.textclassifier.TextClassifier;
-
-import androidx.test.InstrumentationRegistry;
-import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
-
-import com.google.android.textclassifier.AnnotatorModel;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-
-import java.util.List;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public class LegacyIntentClassificationFactoryTest {
-
- private static final String TEXT = "text";
-
- private LegacyClassificationIntentFactory mLegacyIntentClassificationFactory;
-
- @Before
- public void setup() {
- mLegacyIntentClassificationFactory = new LegacyClassificationIntentFactory();
- }
-
- @Test
- public void create_typeDictionary() {
- AnnotatorModel.ClassificationResult classificationResult =
- new AnnotatorModel.ClassificationResult(
- TextClassifier.TYPE_DICTIONARY,
- 1.0f,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- 0L,
- 0L,
- 0d);
-
- List<LabeledIntent> intents =
- mLegacyIntentClassificationFactory.create(
- InstrumentationRegistry.getTargetContext(),
- TEXT,
- /* foreignText */ false,
- null,
- classificationResult);
-
- assertThat(intents).hasSize(1);
- LabeledIntent labeledIntent = intents.get(0);
- Intent intent = labeledIntent.intent;
- assertThat(intent.getAction()).isEqualTo(Intent.ACTION_DEFINE);
- assertThat(intent.getStringExtra(Intent.EXTRA_TEXT)).isEqualTo(TEXT);
- }
-
- @Test
- public void create_translateAndDictionary() {
- AnnotatorModel.ClassificationResult classificationResult =
- new AnnotatorModel.ClassificationResult(
- TextClassifier.TYPE_DICTIONARY,
- 1.0f,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- 0L,
- 0L,
- 0d);
-
- List<LabeledIntent> intents =
- mLegacyIntentClassificationFactory.create(
- InstrumentationRegistry.getTargetContext(),
- TEXT,
- /* foreignText */ true,
- null,
- classificationResult);
-
- assertThat(intents).hasSize(2);
- assertThat(intents.get(0).intent.getAction()).isEqualTo(Intent.ACTION_DEFINE);
- assertThat(intents.get(1).intent.getAction()).isEqualTo(Intent.ACTION_TRANSLATE);
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java b/java/tests/unittests/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java
deleted file mode 100644
index 32840d0..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/intent/TemplateClassificationIntentFactoryTest.java
+++ /dev/null
@@ -1,245 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License
- */
-
-package com.android.textclassifier.intent;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.ArgumentMatchers.same;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.verify;
-
-import android.content.Context;
-import android.content.Intent;
-import android.view.textclassifier.TextClassifier;
-
-import androidx.test.InstrumentationRegistry;
-import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
-
-import com.google.android.textclassifier.AnnotatorModel;
-import com.google.android.textclassifier.RemoteActionTemplate;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-import java.util.List;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public class TemplateClassificationIntentFactoryTest {
-
- private static final String TEXT = "text";
- private static final String TITLE_WITHOUT_ENTITY = "Map";
- private static final String DESCRIPTION = "Opens in Maps";
- private static final String DESCRIPTION_WITH_APP_NAME = "Use %1$s to open Map";
- private static final String ACTION = Intent.ACTION_VIEW;
-
- @Mock private ClassificationIntentFactory mFallback;
- private TemplateClassificationIntentFactory mTemplateClassificationIntentFactory;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
- mTemplateClassificationIntentFactory =
- new TemplateClassificationIntentFactory(new TemplateIntentFactory(), mFallback);
- }
-
- @Test
- public void create_foreignText() {
- AnnotatorModel.ClassificationResult classificationResult =
- new AnnotatorModel.ClassificationResult(
- TextClassifier.TYPE_ADDRESS,
- 1.0f,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- createRemoteActionTemplates(),
- 0L,
- 0L,
- 0d);
-
- List<LabeledIntent> intents =
- mTemplateClassificationIntentFactory.create(
- InstrumentationRegistry.getTargetContext(),
- TEXT,
- /* foreignText */ true,
- null,
- classificationResult);
-
- assertThat(intents).hasSize(2);
- LabeledIntent labeledIntent = intents.get(0);
- assertThat(labeledIntent.titleWithoutEntity).isEqualTo(TITLE_WITHOUT_ENTITY);
- Intent intent = labeledIntent.intent;
- assertThat(intent.getAction()).isEqualTo(ACTION);
-
- labeledIntent = intents.get(1);
- intent = labeledIntent.intent;
- assertThat(intent.getAction()).isEqualTo(Intent.ACTION_TRANSLATE);
- }
-
- @Test
- public void create_notForeignText() {
- AnnotatorModel.ClassificationResult classificationResult =
- new AnnotatorModel.ClassificationResult(
- TextClassifier.TYPE_ADDRESS,
- 1.0f,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- createRemoteActionTemplates(),
- 0L,
- 0L,
- 0d);
-
- List<LabeledIntent> intents =
- mTemplateClassificationIntentFactory.create(
- InstrumentationRegistry.getContext(),
- TEXT,
- /* foreignText */ false,
- null,
- classificationResult);
-
- assertThat(intents).hasSize(1);
- LabeledIntent labeledIntent = intents.get(0);
- assertThat(labeledIntent.titleWithoutEntity).isEqualTo(TITLE_WITHOUT_ENTITY);
- Intent intent = labeledIntent.intent;
- assertThat(intent.getAction()).isEqualTo(ACTION);
- }
-
- @Test
- public void create_nullTemplate() {
- AnnotatorModel.ClassificationResult classificationResult =
- new AnnotatorModel.ClassificationResult(
- TextClassifier.TYPE_ADDRESS,
- 1.0f,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- 0L,
- 0L,
- 0d);
-
- mTemplateClassificationIntentFactory.create(
- InstrumentationRegistry.getContext(),
- TEXT,
- /* foreignText */ false,
- null,
- classificationResult);
-
- verify(mFallback)
- .create(
- same(InstrumentationRegistry.getContext()),
- eq(TEXT),
- eq(false),
- eq(null),
- same(classificationResult));
- }
-
- @Test
- public void create_emptyResult() {
- AnnotatorModel.ClassificationResult classificationResult =
- new AnnotatorModel.ClassificationResult(
- TextClassifier.TYPE_ADDRESS,
- 1.0f,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- new RemoteActionTemplate[0],
- 0L,
- 0L,
- 0d);
-
- mTemplateClassificationIntentFactory.create(
- InstrumentationRegistry.getContext(),
- TEXT,
- /* foreignText */ false,
- null,
- classificationResult);
-
- verify(mFallback, never())
- .create(
- any(Context.class),
- eq(TEXT),
- eq(false),
- eq(null),
- any(AnnotatorModel.ClassificationResult.class));
- }
-
- private static RemoteActionTemplate[] createRemoteActionTemplates() {
- return new RemoteActionTemplate[] {
- new RemoteActionTemplate(
- TITLE_WITHOUT_ENTITY,
- null,
- DESCRIPTION,
- DESCRIPTION_WITH_APP_NAME,
- ACTION,
- null,
- null,
- null,
- null,
- null,
- null,
- null)
- };
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/intent/TemplateIntentFactoryTest.java b/java/tests/unittests/src/com/android/textclassifier/intent/TemplateIntentFactoryTest.java
deleted file mode 100644
index ed12784..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/intent/TemplateIntentFactoryTest.java
+++ /dev/null
@@ -1,269 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License
- */
-package com.android.textclassifier.intent;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import android.content.Intent;
-import android.net.Uri;
-
-import androidx.test.filters.SmallTest;
-import androidx.test.runner.AndroidJUnit4;
-
-import com.google.android.textclassifier.NamedVariant;
-import com.google.android.textclassifier.RemoteActionTemplate;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.MockitoAnnotations;
-
-import java.util.List;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public class TemplateIntentFactoryTest {
-
- private static final String TEXT = "text";
- private static final String TITLE_WITHOUT_ENTITY = "Map";
- private static final String TITLE_WITH_ENTITY = "Map NW14D1";
- private static final String DESCRIPTION = "Check the map";
- private static final String DESCRIPTION_WITH_APP_NAME = "Use %1$s to open map";
- private static final String ACTION = Intent.ACTION_VIEW;
- private static final String DATA = Uri.parse("http://www.android.com").toString();
- private static final String TYPE = "text/html";
- private static final Integer FLAG = Intent.FLAG_ACTIVITY_NEW_TASK;
- private static final String[] CATEGORY =
- new String[] {Intent.CATEGORY_DEFAULT, Intent.CATEGORY_APP_BROWSER};
- private static final String PACKAGE_NAME = "pkg.name";
- private static final String KEY_ONE = "key1";
- private static final String VALUE_ONE = "value1";
- private static final String KEY_TWO = "key2";
- private static final int VALUE_TWO = 42;
-
- private static final NamedVariant[] NAMED_VARIANTS =
- new NamedVariant[] {
- new NamedVariant(KEY_ONE, VALUE_ONE), new NamedVariant(KEY_TWO, VALUE_TWO)
- };
- private static final Integer REQUEST_CODE = 10;
-
- private TemplateIntentFactory mTemplateIntentFactory;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
- mTemplateIntentFactory = new TemplateIntentFactory();
- }
-
- @Test
- public void create_full() {
- RemoteActionTemplate remoteActionTemplate =
- new RemoteActionTemplate(
- TITLE_WITHOUT_ENTITY,
- TITLE_WITH_ENTITY,
- DESCRIPTION,
- DESCRIPTION_WITH_APP_NAME,
- ACTION,
- DATA,
- TYPE,
- FLAG,
- CATEGORY,
- /* packageName */ null,
- NAMED_VARIANTS,
- REQUEST_CODE);
-
- List<LabeledIntent> intents =
- mTemplateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
-
- assertThat(intents).hasSize(1);
- LabeledIntent labeledIntent = intents.get(0);
- assertThat(labeledIntent.titleWithoutEntity).isEqualTo(TITLE_WITHOUT_ENTITY);
- assertThat(labeledIntent.titleWithEntity).isEqualTo(TITLE_WITH_ENTITY);
- assertThat(labeledIntent.description).isEqualTo(DESCRIPTION);
- assertThat(labeledIntent.descriptionWithAppName).isEqualTo(DESCRIPTION_WITH_APP_NAME);
- assertThat(labeledIntent.requestCode).isEqualTo(REQUEST_CODE);
- Intent intent = labeledIntent.intent;
- assertThat(intent.getAction()).isEqualTo(ACTION);
- assertThat(intent.getData().toString()).isEqualTo(DATA);
- assertThat(intent.getType()).isEqualTo(TYPE);
- assertThat(intent.getFlags()).isEqualTo(FLAG);
- assertThat(intent.getCategories()).containsExactly((Object[]) CATEGORY);
- assertThat(intent.getPackage()).isNull();
- assertThat(intent.getStringExtra(KEY_ONE)).isEqualTo(VALUE_ONE);
- assertThat(intent.getIntExtra(KEY_TWO, 0)).isEqualTo(VALUE_TWO);
- }
-
- @Test
- public void normalizesScheme() {
- RemoteActionTemplate remoteActionTemplate =
- new RemoteActionTemplate(
- TITLE_WITHOUT_ENTITY,
- TITLE_WITH_ENTITY,
- DESCRIPTION,
- DESCRIPTION_WITH_APP_NAME,
- ACTION,
- "HTTp://www.android.com",
- TYPE,
- FLAG,
- CATEGORY,
- /* packageName */ null,
- NAMED_VARIANTS,
- REQUEST_CODE);
-
- List<LabeledIntent> intents =
- mTemplateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
-
- String data = intents.get(0).intent.getData().toString();
- assertThat(data).isEqualTo("http://www.android.com");
- }
-
- @Test
- public void create_minimal() {
- RemoteActionTemplate remoteActionTemplate =
- new RemoteActionTemplate(
- TITLE_WITHOUT_ENTITY,
- null,
- DESCRIPTION,
- null,
- ACTION,
- null,
- null,
- null,
- null,
- null,
- null,
- null);
-
- List<LabeledIntent> intents =
- mTemplateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
-
- assertThat(intents).hasSize(1);
- LabeledIntent labeledIntent = intents.get(0);
- assertThat(labeledIntent.titleWithoutEntity).isEqualTo(TITLE_WITHOUT_ENTITY);
- assertThat(labeledIntent.titleWithEntity).isNull();
- assertThat(labeledIntent.description).isEqualTo(DESCRIPTION);
- assertThat(labeledIntent.requestCode).isEqualTo(LabeledIntent.DEFAULT_REQUEST_CODE);
- Intent intent = labeledIntent.intent;
- assertThat(intent.getAction()).isEqualTo(ACTION);
- assertThat(intent.getData()).isNull();
- assertThat(intent.getType()).isNull();
- assertThat(intent.getFlags()).isEqualTo(0);
- assertThat(intent.getCategories()).isNull();
- assertThat(intent.getPackage()).isNull();
- }
-
- @Test
- public void invalidTemplate_nullTemplate() {
- RemoteActionTemplate remoteActionTemplate = null;
-
- List<LabeledIntent> intents =
- mTemplateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
-
- assertThat(intents).isEmpty();
- }
-
- @Test
- public void invalidTemplate_nonEmptyPackageName() {
- RemoteActionTemplate remoteActionTemplate =
- new RemoteActionTemplate(
- TITLE_WITHOUT_ENTITY,
- TITLE_WITH_ENTITY,
- DESCRIPTION,
- DESCRIPTION_WITH_APP_NAME,
- ACTION,
- DATA,
- TYPE,
- FLAG,
- CATEGORY,
- PACKAGE_NAME,
- NAMED_VARIANTS,
- REQUEST_CODE);
-
- List<LabeledIntent> intents =
- mTemplateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
-
- assertThat(intents).isEmpty();
- }
-
- @Test
- public void invalidTemplate_emptyTitle() {
- RemoteActionTemplate remoteActionTemplate =
- new RemoteActionTemplate(
- null,
- null,
- DESCRIPTION,
- DESCRIPTION_WITH_APP_NAME,
- ACTION,
- null,
- null,
- null,
- null,
- null,
- null,
- null);
-
- List<LabeledIntent> intents =
- mTemplateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
-
- assertThat(intents).isEmpty();
- }
-
- @Test
- public void invalidTemplate_emptyDescription() {
- RemoteActionTemplate remoteActionTemplate =
- new RemoteActionTemplate(
- TITLE_WITHOUT_ENTITY,
- TITLE_WITH_ENTITY,
- null,
- null,
- ACTION,
- null,
- null,
- null,
- null,
- null,
- null,
- null);
-
- List<LabeledIntent> intents =
- mTemplateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
-
- assertThat(intents).isEmpty();
- }
-
- @Test
- public void invalidTemplate_emptyIntentAction() {
- RemoteActionTemplate remoteActionTemplate =
- new RemoteActionTemplate(
- TITLE_WITHOUT_ENTITY,
- TITLE_WITH_ENTITY,
- DESCRIPTION,
- DESCRIPTION_WITH_APP_NAME,
- null,
- null,
- null,
- null,
- null,
- null,
- null,
- null);
-
- List<LabeledIntent> intents =
- mTemplateIntentFactory.create(new RemoteActionTemplate[] {remoteActionTemplate});
-
- assertThat(intents).isEmpty();
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/logging/ResultIdUtilsTest.java b/java/tests/unittests/src/com/android/textclassifier/logging/ResultIdUtilsTest.java
deleted file mode 100644
index 3ec07bc..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/logging/ResultIdUtilsTest.java
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.logging;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import androidx.test.core.app.ApplicationProvider;
-import androidx.test.filters.SmallTest;
-
-import org.junit.Test;
-
-import java.util.Collections;
-import java.util.Locale;
-
-@SmallTest
-public class ResultIdUtilsTest {
- private static final int MODEL_VERSION = 703;
- private static final int HASH = 12345;
-
- @Test
- public void createId_customHash() {
- String resultId =
- ResultIdUtils.createId(
- MODEL_VERSION, Collections.singletonList(Locale.ENGLISH), HASH);
-
- assertThat(resultId).isEqualTo("androidtc|en_v703|12345");
- }
-
- @Test
- public void createId_selection() {
- String resultId =
- ResultIdUtils.createId(
- ApplicationProvider.getApplicationContext(),
- "text",
- 1,
- 2,
- MODEL_VERSION,
- Collections.singletonList(Locale.ENGLISH));
-
- assertThat(resultId).matches("androidtc\\|en_v703\\|-?\\d+");
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/logging/SelectionEventConverterTest.java b/java/tests/unittests/src/com/android/textclassifier/logging/SelectionEventConverterTest.java
deleted file mode 100644
index 2760145..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/logging/SelectionEventConverterTest.java
+++ /dev/null
@@ -1,199 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.logging;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import android.view.textclassifier.SelectionEvent;
-import android.view.textclassifier.TextClassification;
-import android.view.textclassifier.TextClassificationContext;
-import android.view.textclassifier.TextClassificationManager;
-import android.view.textclassifier.TextClassifier;
-import android.view.textclassifier.TextClassifierEvent;
-import android.view.textclassifier.TextSelection;
-
-import androidx.test.core.app.ApplicationProvider;
-import androidx.test.filters.SmallTest;
-
-import org.junit.Before;
-import org.junit.Test;
-
-import java.util.ArrayDeque;
-import java.util.Collections;
-import java.util.Deque;
-import java.util.Locale;
-
-@SmallTest
-public class SelectionEventConverterTest {
- private static final String TEXT = "Some text here and there";
- private static final String PKG_NAME = "com.pkg";
- private static final String WIDGET_TYPE = TextClassifier.WIDGET_TYPE_EDITTEXT;
- private static final int START = 2;
- private static final int SMART_START = 1;
- private static final int SMART_END = 3;
- private TestTextClassifier mTestTextClassifier;
- private TextClassifier mSession;
-
- @Before
- public void setup() {
- TextClassificationManager textClassificationManager =
- ApplicationProvider.getApplicationContext()
- .getSystemService(TextClassificationManager.class);
- mTestTextClassifier = new TestTextClassifier();
- textClassificationManager.setTextClassifier(mTestTextClassifier);
- mSession = textClassificationManager.createTextClassificationSession(createEventContext());
- }
-
- @Test
- public void convert_started() {
- mSession.onSelectionEvent(
- SelectionEvent.createSelectionStartedEvent(
- SelectionEvent.INVOCATION_MANUAL, START));
-
- SelectionEvent interceptedEvent = mTestTextClassifier.popLastSelectionEvent();
- TextClassifierEvent textClassifierEvent =
- SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
-
- assertEventContext(textClassifierEvent.getEventContext());
- assertThat(textClassifierEvent.getEventIndex()).isEqualTo(0);
- assertThat(textClassifierEvent.getEventType())
- .isEqualTo(TextClassifierEvent.TYPE_SELECTION_STARTED);
- }
-
- @Test
- public void convert_smartSelection() {
- mSession.onSelectionEvent(
- SelectionEvent.createSelectionStartedEvent(
- SelectionEvent.INVOCATION_MANUAL, START));
- String resultId =
- ResultIdUtils.createId(
- 702, Collections.singletonList(Locale.ENGLISH), /*hash=*/ 12345);
- mSession.onSelectionEvent(
- SelectionEvent.createSelectionActionEvent(
- SMART_START,
- SMART_END,
- SelectionEvent.ACTION_SMART_SHARE,
- new TextClassification.Builder()
- .setEntityType(TextClassifier.TYPE_ADDRESS, 1.0f)
- .setId(resultId)
- .build()));
-
- SelectionEvent interceptedEvent = mTestTextClassifier.popLastSelectionEvent();
- TextClassifierEvent.TextSelectionEvent textSelectionEvent =
- (TextClassifierEvent.TextSelectionEvent)
- SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
-
- assertEventContext(textSelectionEvent.getEventContext());
- assertThat(textSelectionEvent.getRelativeWordStartIndex()).isEqualTo(-1);
- assertThat(textSelectionEvent.getRelativeWordEndIndex()).isEqualTo(1);
- assertThat(textSelectionEvent.getEventType())
- .isEqualTo(TextClassifierEvent.TYPE_SMART_ACTION);
- assertThat(textSelectionEvent.getEventIndex()).isEqualTo(1);
- assertThat(textSelectionEvent.getEntityTypes())
- .asList()
- .containsExactly(TextClassifier.TYPE_ADDRESS);
- assertThat(textSelectionEvent.getResultId()).isEqualTo(resultId);
- }
-
- @Test
- public void convert_smartShare() {
- mSession.onSelectionEvent(
- SelectionEvent.createSelectionStartedEvent(
- SelectionEvent.INVOCATION_MANUAL, START));
- String resultId =
- ResultIdUtils.createId(
- 702, Collections.singletonList(Locale.ENGLISH), /*hash=*/ 12345);
- mSession.onSelectionEvent(
- SelectionEvent.createSelectionModifiedEvent(
- SMART_START,
- SMART_END,
- new TextSelection.Builder(SMART_START, SMART_END)
- .setEntityType(TextClassifier.TYPE_ADDRESS, 1.0f)
- .setId(resultId)
- .build()));
-
- SelectionEvent interceptedEvent = mTestTextClassifier.popLastSelectionEvent();
- TextClassifierEvent.TextSelectionEvent textSelectionEvent =
- (TextClassifierEvent.TextSelectionEvent)
- SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
-
- assertEventContext(textSelectionEvent.getEventContext());
- assertThat(textSelectionEvent.getRelativeSuggestedWordStartIndex()).isEqualTo(-1);
- assertThat(textSelectionEvent.getRelativeSuggestedWordEndIndex()).isEqualTo(1);
- assertThat(textSelectionEvent.getEventType())
- .isEqualTo(TextClassifierEvent.TYPE_SMART_SELECTION_MULTI);
- assertThat(textSelectionEvent.getEventIndex()).isEqualTo(1);
- assertThat(textSelectionEvent.getEntityTypes())
- .asList()
- .containsExactly(TextClassifier.TYPE_ADDRESS);
- assertThat(textSelectionEvent.getResultId()).isEqualTo(resultId);
- }
-
- @Test
- public void convert_smartLinkify() {
- mSession.onSelectionEvent(
- SelectionEvent.createSelectionStartedEvent(SelectionEvent.INVOCATION_LINK, START));
- String resultId =
- ResultIdUtils.createId(
- 702, Collections.singletonList(Locale.ENGLISH), /*hash=*/ 12345);
- mSession.onSelectionEvent(
- SelectionEvent.createSelectionModifiedEvent(
- SMART_START,
- SMART_END,
- new TextSelection.Builder(SMART_START, SMART_END)
- .setEntityType(TextClassifier.TYPE_ADDRESS, 1.0f)
- .setId(resultId)
- .build()));
-
- SelectionEvent interceptedEvent = mTestTextClassifier.popLastSelectionEvent();
- TextClassifierEvent.TextLinkifyEvent textLinkifyEvent =
- (TextClassifierEvent.TextLinkifyEvent)
- SelectionEventConverter.toTextClassifierEvent(interceptedEvent);
-
- assertEventContext(textLinkifyEvent.getEventContext());
- assertThat(textLinkifyEvent.getEventType())
- .isEqualTo(TextClassifierEvent.TYPE_SMART_SELECTION_MULTI);
- assertThat(textLinkifyEvent.getEventIndex()).isEqualTo(1);
- assertThat(textLinkifyEvent.getEntityTypes())
- .asList()
- .containsExactly(TextClassifier.TYPE_ADDRESS);
- assertThat(textLinkifyEvent.getResultId()).isEqualTo(resultId);
- }
-
- private static TextClassificationContext createEventContext() {
- return new TextClassificationContext.Builder(PKG_NAME, TextClassifier.WIDGET_TYPE_EDITTEXT)
- .build();
- }
-
- private static void assertEventContext(TextClassificationContext eventContext) {
- assertThat(eventContext.getPackageName()).isEqualTo(PKG_NAME);
- assertThat(eventContext.getWidgetType()).isEqualTo(WIDGET_TYPE);
- }
-
- private static class TestTextClassifier implements TextClassifier {
- private Deque<SelectionEvent> mSelectionEvents = new ArrayDeque<>();
-
- @Override
- public void onSelectionEvent(SelectionEvent event) {
- mSelectionEvents.push(event);
- }
-
- SelectionEvent popLastSelectionEvent() {
- return mSelectionEvents.pop();
- }
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/subjects/EntitySubject.java b/java/tests/unittests/src/com/android/textclassifier/subjects/EntitySubject.java
deleted file mode 100644
index 17a339f..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/subjects/EntitySubject.java
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.subjects;
-
-import static com.google.common.truth.Truth.assertAbout;
-
-import com.android.textclassifier.Entity;
-
-import com.google.common.truth.FailureMetadata;
-import com.google.common.truth.MathUtil;
-import com.google.common.truth.Subject;
-
-import javax.annotation.Nullable;
-
-/** Test helper for checking {@link com.android.textclassifier.Entity} results. */
-public final class EntitySubject extends Subject<EntitySubject, Entity> {
-
- private static final float TOLERANCE = 0.0001f;
- private Entity mEntity;
-
- public static EntitySubject assertThat(@Nullable Entity entity) {
- return assertAbout(EntitySubject::new).that(entity);
- }
-
- private EntitySubject(FailureMetadata failureMetadata, @Nullable Entity entity) {
- super(failureMetadata, entity);
- mEntity = entity;
- }
-
- public void isMatchWithinTolerance(@Nullable Entity entity) {
- if (!entity.getEntityType().equals(mEntity.getEntityType())) {
- failWithActual("expected to have type", entity.getEntityType());
- }
- if (!MathUtil.equalWithinTolerance(entity.getScore(), mEntity.getScore(), TOLERANCE)) {
- failWithActual("expected to have confidence score", entity.getScore());
- }
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzerTest.java b/java/tests/unittests/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzerTest.java
deleted file mode 100644
index 571abbd..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzerTest.java
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.ulp;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-import android.content.Context;
-
-import androidx.room.Room;
-import androidx.test.core.app.ApplicationProvider;
-import androidx.test.filters.SmallTest;
-
-import com.android.textclassifier.TextClassificationConstants;
-import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
-import com.android.textclassifier.ulp.database.LanguageSignalInfo;
-
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-import java.util.Arrays;
-import java.util.Locale;
-
-/** Testing {@link BasicLanguageProficiencyAnalyzer} in an in-memory database. */
-@SmallTest
-public class BasicLanguageProficiencyAnalyzerTest {
-
- private static final String PRIMARY_SYSTEM_LANGUAGE = Locale.CHINESE.toLanguageTag();
- private static final String SECONDARY_SYSTEM_LANGUAGE = Locale.ENGLISH.toLanguageTag();
- private static final String NON_SYSTEM_LANGUAGE = Locale.JAPANESE.toLanguageTag();
-
- private LanguageProfileDatabase mDatabase;
- private BasicLanguageProficiencyAnalyzer mProficiencyAnalyzer;
- @Mock private SystemLanguagesProvider mSystemLanguagesProvider;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
-
- Context context = ApplicationProvider.getApplicationContext();
- TextClassificationConstants textClassificationConstants =
- mock(TextClassificationConstants.class);
- mDatabase = Room.inMemoryDatabaseBuilder(context, LanguageProfileDatabase.class).build();
- mProficiencyAnalyzer =
- new BasicLanguageProficiencyAnalyzer(
- textClassificationConstants, mDatabase, mSystemLanguagesProvider);
- when(mSystemLanguagesProvider.getSystemLanguageTags())
- .thenReturn(Arrays.asList(PRIMARY_SYSTEM_LANGUAGE, SECONDARY_SYSTEM_LANGUAGE));
- when(textClassificationConstants.getLanguageProficiencyBootstrappingCount())
- .thenReturn(100);
- }
-
- @After
- public void close() {
- mDatabase.close();
- }
-
- @Test
- public void canUnderstand_emptyDatabase() {
- assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
- assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(0.5f);
- assertThat(mProficiencyAnalyzer.canUnderstand(NON_SYSTEM_LANGUAGE)).isEqualTo(0f);
- }
-
- @Test
- public void canUnderstand_validRequest() {
- mDatabase
- .languageInfoDao()
- .insertLanguageInfo(
- new LanguageSignalInfo(
- PRIMARY_SYSTEM_LANGUAGE,
- LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS,
- 100));
- mDatabase
- .languageInfoDao()
- .insertLanguageInfo(
- new LanguageSignalInfo(
- SECONDARY_SYSTEM_LANGUAGE,
- LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS,
- 30));
-
- assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
- assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(0.4f);
- assertThat(mProficiencyAnalyzer.canUnderstand(NON_SYSTEM_LANGUAGE)).isEqualTo(0f);
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzerTest.java b/java/tests/unittests/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzerTest.java
deleted file mode 100644
index cb3fc02..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzerTest.java
+++ /dev/null
@@ -1,152 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.ulp;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-import android.content.Context;
-
-import androidx.room.Room;
-import androidx.test.core.app.ApplicationProvider;
-import androidx.test.filters.SmallTest;
-
-import com.android.textclassifier.TextClassificationConstants;
-import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
-import com.android.textclassifier.ulp.database.LanguageSignalInfo;
-
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Locale;
-
-/** Testing {@link KmeansLanguageProficiencyAnalyzer} in an in-memory database. */
-@SmallTest
-public class KmeansLanguageProficiencyAnalyzerTest {
-
- private static final String PRIMARY_SYSTEM_LANGUAGE = Locale.CHINESE.toLanguageTag();
- private static final String SECONDARY_SYSTEM_LANGUAGE = Locale.ENGLISH.toLanguageTag();
- private static final String NORMAL_LANGUAGE = Locale.JAPANESE.toLanguageTag();
-
- private LanguageProfileDatabase mDatabase;
- private KmeansLanguageProficiencyAnalyzer mProficiencyAnalyzer;
- @Mock private SystemLanguagesProvider mSystemLanguagesProvider;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
-
- Context context = ApplicationProvider.getApplicationContext();
- TextClassificationConstants textClassificationConstants =
- mock(TextClassificationConstants.class);
- mDatabase = Room.inMemoryDatabaseBuilder(context, LanguageProfileDatabase.class).build();
- mProficiencyAnalyzer =
- new KmeansLanguageProficiencyAnalyzer(
- textClassificationConstants, mDatabase, mSystemLanguagesProvider);
- when(mSystemLanguagesProvider.getSystemLanguageTags())
- .thenReturn(Arrays.asList(PRIMARY_SYSTEM_LANGUAGE, SECONDARY_SYSTEM_LANGUAGE));
- when(textClassificationConstants.getLanguageProficiencyBootstrappingCount())
- .thenReturn(100);
- }
-
- @After
- public void close() {
- mDatabase.close();
- }
-
- @Test
- public void canUnderstand_emptyDatabase() {
- assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
- assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(0.5f);
- assertThat(mProficiencyAnalyzer.canUnderstand(NORMAL_LANGUAGE)).isEqualTo(0f);
- }
-
- @Test
- public void canUnderstand_oneLanguage() {
- when(mSystemLanguagesProvider.getSystemLanguageTags())
- .thenReturn(Collections.singletonList(PRIMARY_SYSTEM_LANGUAGE));
- mDatabase
- .languageInfoDao()
- .insertLanguageInfo(
- new LanguageSignalInfo(
- PRIMARY_SYSTEM_LANGUAGE,
- LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS,
- 1));
-
- assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
- assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(0f);
- assertThat(mProficiencyAnalyzer.canUnderstand(NORMAL_LANGUAGE)).isEqualTo(0f);
- }
-
- @Test
- public void canUnderstand_twoLanguages() {
- mDatabase
- .languageInfoDao()
- .insertLanguageInfo(
- new LanguageSignalInfo(
- PRIMARY_SYSTEM_LANGUAGE,
- LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS,
- 100));
- mDatabase
- .languageInfoDao()
- .insertLanguageInfo(
- new LanguageSignalInfo(
- SECONDARY_SYSTEM_LANGUAGE,
- LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS,
- 50));
-
- assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
- assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(0.5f);
- assertThat(mProficiencyAnalyzer.canUnderstand(NORMAL_LANGUAGE)).isEqualTo(0f);
- }
-
- @Test
- public void canUnderstand_threeLanguages() {
- mDatabase
- .languageInfoDao()
- .insertLanguageInfo(
- new LanguageSignalInfo(
- PRIMARY_SYSTEM_LANGUAGE,
- LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS,
- 100));
- mDatabase
- .languageInfoDao()
- .insertLanguageInfo(
- new LanguageSignalInfo(
- SECONDARY_SYSTEM_LANGUAGE,
- LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS,
- 75));
- mDatabase
- .languageInfoDao()
- .insertLanguageInfo(
- new LanguageSignalInfo(
- NORMAL_LANGUAGE,
- LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS,
- 2));
-
- assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
- assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(1f);
- assertThat(mProficiencyAnalyzer.canUnderstand(NORMAL_LANGUAGE)).isEqualTo(0f);
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluatorTest.java b/java/tests/unittests/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluatorTest.java
deleted file mode 100644
index 387cb61..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluatorTest.java
+++ /dev/null
@@ -1,165 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.ulp;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import static org.mockito.Mockito.when;
-
-import androidx.test.filters.SmallTest;
-
-import com.google.android.collect.Sets;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.Mock;
-import org.mockito.Mockito;
-import org.mockito.MockitoAnnotations;
-import org.mockito.stubbing.Answer;
-
-import java.util.Arrays;
-import java.util.Set;
-
-@SmallTest
-public class LanguageProficiencyEvaluatorTest {
- private static final float EPSILON = 0.01f;
- private LanguageProficiencyEvaluator mLanguageProficiencyEvaluator;
-
- @Mock private SystemLanguagesProvider mSystemLanguagesProvider;
-
- private static final String SYSTEM_LANGUAGE_EN = "en";
- private static final String SYSTEM_LANGUAGE_ZH = "zh";
- private static final String NORMAL_LANGUAGE_JP = "jp";
- private static final String NORMAL_LANGUAGE_FR = "fr";
- private static final String NORMAL_LANGUAGE_PL = "pl";
- private static final Set<String> EVALUATION_LANGUAGES =
- Sets.newArraySet(
- SYSTEM_LANGUAGE_EN,
- SYSTEM_LANGUAGE_ZH,
- NORMAL_LANGUAGE_JP,
- NORMAL_LANGUAGE_FR,
- NORMAL_LANGUAGE_PL);
-
- @Mock private LanguageProficiencyAnalyzer mLanguageProficiencyAnalyzer;
-
- @Before
- public void setUp() {
- MockitoAnnotations.initMocks(this);
- when(mSystemLanguagesProvider.getSystemLanguageTags())
- .thenReturn(Arrays.asList(SYSTEM_LANGUAGE_EN, SYSTEM_LANGUAGE_ZH));
- mLanguageProficiencyEvaluator = new LanguageProficiencyEvaluator(mSystemLanguagesProvider);
- }
-
- @Test
- public void evaluate_allCorrect() {
- when(mLanguageProficiencyAnalyzer.canUnderstand(Mockito.anyString()))
- .thenAnswer(
- (Answer<Float>)
- invocation -> {
- String languageTag = invocation.getArgument(0);
- if (languageTag.equals(SYSTEM_LANGUAGE_EN)
- || languageTag.equals(SYSTEM_LANGUAGE_ZH)) {
- return 1f;
- }
- return 0f;
- });
-
- LanguageProficiencyEvaluator.EvaluationResult evaluationResult =
- mLanguageProficiencyEvaluator.evaluate(
- mLanguageProficiencyAnalyzer, EVALUATION_LANGUAGES);
-
- assertThat(evaluationResult.truePositive).isEqualTo(2);
- assertThat(evaluationResult.trueNegative).isEqualTo(3);
- assertThat(evaluationResult.falsePositive).isEqualTo(0);
- assertThat(evaluationResult.falseNegative).isEqualTo(0);
- assertThat(evaluationResult.computePrecisionOfPositiveClass()).isWithin(EPSILON).of(1f);
- assertThat(evaluationResult.computePrecisionOfNegativeClass()).isWithin(EPSILON).of(1f);
- assertThat(evaluationResult.computeRecallOfPositiveClass()).isWithin(EPSILON).of(1f);
- assertThat(evaluationResult.computeRecallOfNegativeClass()).isWithin(EPSILON).of(1f);
- assertThat(evaluationResult.computeF1ScoreOfPositiveClass()).isWithin(EPSILON).of(1f);
- assertThat(evaluationResult.computeF1ScoreOfNegativeClass()).isWithin(EPSILON).of(1f);
- }
-
- @Test
- public void evaluate_allWrong() {
- when(mLanguageProficiencyAnalyzer.canUnderstand(Mockito.anyString()))
- .thenAnswer(
- (Answer<Float>)
- invocation -> {
- String languageTag = invocation.getArgument(0);
- if (languageTag.equals(SYSTEM_LANGUAGE_EN)
- || languageTag.equals(SYSTEM_LANGUAGE_ZH)) {
- return 0f;
- }
- return 1f;
- });
-
- LanguageProficiencyEvaluator.EvaluationResult evaluationResult =
- mLanguageProficiencyEvaluator.evaluate(
- mLanguageProficiencyAnalyzer, EVALUATION_LANGUAGES);
-
- assertThat(evaluationResult.truePositive).isEqualTo(0);
- assertThat(evaluationResult.trueNegative).isEqualTo(0);
- assertThat(evaluationResult.falsePositive).isEqualTo(3);
- assertThat(evaluationResult.falseNegative).isEqualTo(2);
- assertThat(evaluationResult.computePrecisionOfPositiveClass()).isWithin(EPSILON).of(0f);
- assertThat(evaluationResult.computePrecisionOfNegativeClass()).isWithin(EPSILON).of(0f);
- assertThat(evaluationResult.computeRecallOfPositiveClass()).isWithin(EPSILON).of(0f);
- assertThat(evaluationResult.computeRecallOfNegativeClass()).isWithin(EPSILON).of(0f);
- assertThat(evaluationResult.computeF1ScoreOfPositiveClass()).isWithin(EPSILON).of(0f);
- assertThat(evaluationResult.computeF1ScoreOfNegativeClass()).isWithin(EPSILON).of(0f);
- }
-
- @Test
- public void evaluate_mixed() {
- when(mLanguageProficiencyAnalyzer.canUnderstand(Mockito.anyString()))
- .thenAnswer(
- (Answer<Float>)
- invocation -> {
- String languageTag = invocation.getArgument(0);
- switch (languageTag) {
- case SYSTEM_LANGUAGE_EN:
- return 1f;
- case SYSTEM_LANGUAGE_ZH:
- return 0f;
- case NORMAL_LANGUAGE_FR:
- return 0f;
- case NORMAL_LANGUAGE_JP:
- return 0f;
- case NORMAL_LANGUAGE_PL:
- return 1f;
- }
- throw new IllegalArgumentException(
- "unexpected language: " + languageTag);
- });
-
- LanguageProficiencyEvaluator.EvaluationResult evaluationResult =
- mLanguageProficiencyEvaluator.evaluate(
- mLanguageProficiencyAnalyzer, EVALUATION_LANGUAGES);
-
- assertThat(evaluationResult.truePositive).isEqualTo(1);
- assertThat(evaluationResult.trueNegative).isEqualTo(2);
- assertThat(evaluationResult.falsePositive).isEqualTo(1);
- assertThat(evaluationResult.falseNegative).isEqualTo(1);
- assertThat(evaluationResult.computePrecisionOfPositiveClass()).isWithin(EPSILON).of(0.5f);
- assertThat(evaluationResult.computePrecisionOfNegativeClass()).isWithin(EPSILON).of(0.66f);
- assertThat(evaluationResult.computeRecallOfPositiveClass()).isWithin(EPSILON).of(0.5f);
- assertThat(evaluationResult.computeRecallOfNegativeClass()).isWithin(EPSILON).of(0.66f);
- assertThat(evaluationResult.computeF1ScoreOfPositiveClass()).isWithin(EPSILON).of(0.5f);
- assertThat(evaluationResult.computeF1ScoreOfNegativeClass()).isWithin(EPSILON).of(0.66f);
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/ulp/LanguageProfileAnalyzerTest.java b/java/tests/unittests/src/com/android/textclassifier/ulp/LanguageProfileAnalyzerTest.java
deleted file mode 100644
index 8b21a9f..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/ulp/LanguageProfileAnalyzerTest.java
+++ /dev/null
@@ -1,145 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.ulp;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import static org.mockito.ArgumentMatchers.anyString;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
-
-import android.content.Context;
-
-import androidx.room.Room;
-import androidx.test.core.app.ApplicationProvider;
-import androidx.test.filters.SmallTest;
-
-import com.android.textclassifier.Entity;
-import com.android.textclassifier.TextClassificationConstants;
-import com.android.textclassifier.subjects.EntitySubject;
-import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
-import com.android.textclassifier.ulp.database.LanguageSignalInfo;
-
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-/** Testing {@link LanguageProfileAnalyzer} in an inMemoryDatabase. */
-@SmallTest
-public class LanguageProfileAnalyzerTest {
-
- private static final String SYSTEM_LANGUAGE_CODE = "en";
- private static final String LOCATION_LANGUAGE_CODE = "jp";
- private static final String NORMAL_LANGUAGE_CODE = "pl";
-
- private LanguageProfileDatabase mDatabase;
- private LanguageProfileAnalyzer mLanguageProfileAnalyzer;
- @Mock private LocationSignalProvider mLocationSignalProvider;
- @Mock private SystemLanguagesProvider mSystemLanguagesProvider;
- @Mock private LanguageProficiencyAnalyzer mLanguageProficiencyAnalyzer;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
-
- Context mContext = ApplicationProvider.getApplicationContext();
- mDatabase = Room.inMemoryDatabaseBuilder(mContext, LanguageProfileDatabase.class).build();
- when(mLocationSignalProvider.detectLanguageTag()).thenReturn(LOCATION_LANGUAGE_CODE);
- when(mSystemLanguagesProvider.getSystemLanguageTags())
- .thenReturn(Collections.singletonList(SYSTEM_LANGUAGE_CODE));
- when(mLanguageProficiencyAnalyzer.canUnderstand(anyString())).thenReturn(1.0f);
- TextClassificationConstants customTextClassificationConstants =
- mock(TextClassificationConstants.class);
- when(customTextClassificationConstants.getFrequentLanguagesBootstrappingCount())
- .thenReturn(100);
- mLanguageProfileAnalyzer =
- new LanguageProfileAnalyzer(
- mContext,
- customTextClassificationConstants,
- mDatabase,
- mLanguageProficiencyAnalyzer,
- mLocationSignalProvider,
- mSystemLanguagesProvider);
- }
-
- @After
- public void close() {
- mDatabase.close();
- }
-
- @Test
- public void getFrequentLanguages_emptyDatabase() {
- List<Entity> frequentLanguages =
- mLanguageProfileAnalyzer.getFrequentLanguages(LanguageSignalInfo.CLASSIFY_TEXT);
-
- assertThat(frequentLanguages).hasSize(2);
- EntitySubject.assertThat(frequentLanguages.get(0))
- .isMatchWithinTolerance(new Entity(SYSTEM_LANGUAGE_CODE, 1.0f));
- EntitySubject.assertThat(frequentLanguages.get(1))
- .isMatchWithinTolerance(new Entity(LOCATION_LANGUAGE_CODE, 1.0f));
- }
-
- @Test
- public void getFrequentLanguages_mixedSignal() {
- insertSignal(NORMAL_LANGUAGE_CODE, LanguageSignalInfo.CLASSIFY_TEXT, 50);
- insertSignal(SYSTEM_LANGUAGE_CODE, LanguageSignalInfo.CLASSIFY_TEXT, 100);
- // Unrelated signals.
- insertSignal(NORMAL_LANGUAGE_CODE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 100);
- insertSignal(SYSTEM_LANGUAGE_CODE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 100);
- insertSignal(LOCATION_LANGUAGE_CODE, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 100);
-
- List<Entity> frequentLanguages =
- mLanguageProfileAnalyzer.getFrequentLanguages(LanguageSignalInfo.CLASSIFY_TEXT);
-
- assertThat(frequentLanguages).hasSize(3);
- EntitySubject.assertThat(frequentLanguages.get(0))
- .isMatchWithinTolerance(new Entity(SYSTEM_LANGUAGE_CODE, 1.0f));
- EntitySubject.assertThat(frequentLanguages.get(1))
- .isMatchWithinTolerance(new Entity(LOCATION_LANGUAGE_CODE, 0.5f));
- EntitySubject.assertThat(frequentLanguages.get(2))
- .isMatchWithinTolerance(new Entity(NORMAL_LANGUAGE_CODE, 0.25f));
- }
-
- @Test
- public void getFrequentLanguages_bothSystemLanguageAndLocationLanguage() {
- when(mLocationSignalProvider.detectLanguageTag()).thenReturn("en");
- when(mSystemLanguagesProvider.getSystemLanguageTags())
- .thenReturn(Arrays.asList("en", "jp"));
-
- List<Entity> frequentLanguages =
- mLanguageProfileAnalyzer.getFrequentLanguages(LanguageSignalInfo.CLASSIFY_TEXT);
-
- assertThat(frequentLanguages).hasSize(2);
- EntitySubject.assertThat(frequentLanguages.get(0))
- .isMatchWithinTolerance(new Entity("en", 1.0f));
- EntitySubject.assertThat(frequentLanguages.get(1))
- .isMatchWithinTolerance(new Entity("jp", 0.5f));
- }
-
- private void insertSignal(
- String languageTag, @LanguageSignalInfo.Source int source, int count) {
- mDatabase
- .languageInfoDao()
- .insertLanguageInfo(new LanguageSignalInfo(languageTag, source, count));
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/ulp/LanguageProfileUpdaterTest.java b/java/tests/unittests/src/com/android/textclassifier/ulp/LanguageProfileUpdaterTest.java
deleted file mode 100644
index b90704a..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/ulp/LanguageProfileUpdaterTest.java
+++ /dev/null
@@ -1,230 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.ulp;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import android.app.Person;
-import android.content.Context;
-import android.os.Bundle;
-import android.view.textclassifier.ConversationActions;
-import android.view.textclassifier.TextClassification;
-
-import androidx.room.Room;
-import androidx.test.core.app.ApplicationProvider;
-import androidx.test.filters.SmallTest;
-
-import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
-import com.android.textclassifier.ulp.database.LanguageSignalInfo;
-
-import com.google.common.collect.ImmutableList;
-import com.google.common.util.concurrent.ListeningExecutorService;
-import com.google.common.util.concurrent.MoreExecutors;
-
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-
-import java.time.ZoneId;
-import java.time.ZonedDateTime;
-import java.util.Arrays;
-import java.util.List;
-import java.util.Locale;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executors;
-import java.util.function.Function;
-
-/** Testing {@link LanguageProfileUpdater} in an in-memory database. */
-@SmallTest
-public class LanguageProfileUpdaterTest {
-
- private static final String NOTIFICATION_KEY = "test_notification";
- private static final String LOCALE_TAG_US = Locale.US.toLanguageTag();
- private static final String LOCALE_TAG_CHINA = Locale.CHINA.toLanguageTag();
- private static final String TEXT_ONE = "hello world";
- private static final String TEXT_TWO = "你好!";
- private static final Function<CharSequence, List<String>> LANGUAGE_DETECTOR_US =
- charSequence -> ImmutableList.of(LOCALE_TAG_US);
- private static final Function<CharSequence, List<String>> LANGUAGE_DETECTOR_CHINA =
- charSequence -> ImmutableList.of(LOCALE_TAG_CHINA);
- private static final Person PERSON = new Person.Builder().build();
- private static final ZonedDateTime TIME_ONE =
- ZonedDateTime.of(2019, 7, 21, 12, 12, 12, 12, ZoneId.systemDefault());
- private static final ZonedDateTime TIME_TWO =
- ZonedDateTime.of(2019, 7, 21, 12, 20, 20, 12, ZoneId.systemDefault());
- private static final ConversationActions.Message MSG_ONE =
- new ConversationActions.Message.Builder(PERSON)
- .setReferenceTime(TIME_ONE)
- .setText(TEXT_ONE)
- .setExtras(new Bundle())
- .build();
- private static final ConversationActions.Message MSG_TWO =
- new ConversationActions.Message.Builder(PERSON)
- .setReferenceTime(TIME_TWO)
- .setText("where are you?")
- .setExtras(new Bundle())
- .build();
- private static final ConversationActions.Message MSG_THREE =
- new ConversationActions.Message.Builder(PERSON)
- .setReferenceTime(TIME_TWO)
- .setText(TEXT_TWO)
- .setExtras(new Bundle())
- .build();
- private static final ConversationActions.Request CONVERSATION_ACTION_REQUEST_ONE =
- new ConversationActions.Request.Builder(Arrays.asList(MSG_ONE)).build();
- private static final ConversationActions.Request CONVERSATION_ACTION_REQUEST_TWO =
- new ConversationActions.Request.Builder(Arrays.asList(MSG_TWO)).build();
- private static final TextClassification.Request TEXT_CLASSIFICATION_REQUEST_ONE =
- new TextClassification.Request.Builder(TEXT_ONE, 0, 2).build();
- private static final LanguageSignalInfo US_INFO_ONE_FOR_CONVERSATION_ACTION_ONE =
- new LanguageSignalInfo(
- LOCALE_TAG_US, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 1);
- private static final LanguageSignalInfo US_INFO_ONE_FOR_CONVERSATION_ACTION_TWO =
- new LanguageSignalInfo(
- LOCALE_TAG_US, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 2);
- private static final LanguageSignalInfo US_INFO_ONE_FOR_CLASSIFY_TEXT =
- new LanguageSignalInfo(LOCALE_TAG_US, LanguageSignalInfo.CLASSIFY_TEXT, 1);
-
- private LanguageProfileUpdater mLanguageProfileUpdater;
- private LanguageProfileDatabase mDatabase;
-
- @Before
- public void setup() {
- Context mContext = ApplicationProvider.getApplicationContext();
- ListeningExecutorService mExecutorService =
- MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor());
- mDatabase = Room.inMemoryDatabaseBuilder(mContext, LanguageProfileDatabase.class).build();
- mLanguageProfileUpdater = new LanguageProfileUpdater(mExecutorService, mDatabase);
- }
-
- @After
- public void close() {
- mDatabase.close();
- }
-
- @Test
- public void updateFromConversationActionsAsync_oneMessage()
- throws ExecutionException, InterruptedException {
- mLanguageProfileUpdater
- .updateFromConversationActionsAsync(
- CONVERSATION_ACTION_REQUEST_ONE, LANGUAGE_DETECTOR_US)
- .get();
- List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
-
- assertThat(infos).hasSize(1);
- LanguageSignalInfo info = infos.get(0);
- assertThat(info).isEqualTo(US_INFO_ONE_FOR_CONVERSATION_ACTION_ONE);
- }
-
- /** Notification keys for these two messages are DEFAULT_NOTIFICATION_KEY */
- @Test
- public void updateFromConversationActionsAsync_twoMessagesInSameNotificationWithSameLanguage()
- throws ExecutionException, InterruptedException {
- mLanguageProfileUpdater
- .updateFromConversationActionsAsync(
- CONVERSATION_ACTION_REQUEST_ONE, LANGUAGE_DETECTOR_US)
- .get();
- mLanguageProfileUpdater
- .updateFromConversationActionsAsync(
- CONVERSATION_ACTION_REQUEST_TWO, LANGUAGE_DETECTOR_US)
- .get();
- List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
-
- assertThat(infos).hasSize(1);
- LanguageSignalInfo info = infos.get(0);
- assertThat(info).isEqualTo(US_INFO_ONE_FOR_CONVERSATION_ACTION_TWO);
- }
-
- @Test
- public void updateFromConversationActionsAsync_twoMessagesInDifferentNotifications()
- throws ExecutionException, InterruptedException {
- mLanguageProfileUpdater
- .updateFromConversationActionsAsync(
- CONVERSATION_ACTION_REQUEST_ONE, LANGUAGE_DETECTOR_US)
- .get();
- Bundle extra = new Bundle();
- extra.putString(LanguageProfileUpdater.NOTIFICATION_KEY, NOTIFICATION_KEY);
- ConversationActions.Request newRequest =
- new ConversationActions.Request.Builder(Arrays.asList(MSG_TWO))
- .setExtras(extra)
- .build();
- mLanguageProfileUpdater
- .updateFromConversationActionsAsync(newRequest, LANGUAGE_DETECTOR_US)
- .get();
- List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
-
- assertThat(infos).hasSize(1);
- LanguageSignalInfo info = infos.get(0);
- assertThat(info).isEqualTo(US_INFO_ONE_FOR_CONVERSATION_ACTION_TWO);
- }
-
- @Test
- public void updateFromConversationActionsAsync_twoMessagesInDifferentLanguage()
- throws ExecutionException, InterruptedException {
- mLanguageProfileUpdater
- .updateFromConversationActionsAsync(
- CONVERSATION_ACTION_REQUEST_ONE, LANGUAGE_DETECTOR_US)
- .get();
- ConversationActions.Request newRequest =
- new ConversationActions.Request.Builder(Arrays.asList(MSG_THREE)).build();
- mLanguageProfileUpdater
- .updateFromConversationActionsAsync(newRequest, LANGUAGE_DETECTOR_CHINA)
- .get();
- List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
-
- assertThat(infos).hasSize(2);
- LanguageSignalInfo infoOne = infos.get(0);
- LanguageSignalInfo infoTwo = infos.get(1);
- assertThat(infoOne).isEqualTo(US_INFO_ONE_FOR_CONVERSATION_ACTION_ONE);
- assertThat(infoTwo)
- .isEqualTo(
- new LanguageSignalInfo(
- LOCALE_TAG_CHINA,
- LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS,
- 1));
- }
-
- @Test
- public void updateFromClassifyTextAsync_classifyText()
- throws ExecutionException, InterruptedException {
- mLanguageProfileUpdater.updateFromClassifyTextAsync(ImmutableList.of(LOCALE_TAG_US)).get();
- List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
-
- assertThat(infos).hasSize(1);
- LanguageSignalInfo info = infos.get(0);
- assertThat(info).isEqualTo(US_INFO_ONE_FOR_CLASSIFY_TEXT);
- }
-
- @Test
- public void updateFromClassifyTextAsync_classifyTextTwice()
- throws ExecutionException, InterruptedException {
- mLanguageProfileUpdater.updateFromClassifyTextAsync(ImmutableList.of(LOCALE_TAG_US)).get();
- mLanguageProfileUpdater
- .updateFromClassifyTextAsync(ImmutableList.of(LOCALE_TAG_CHINA))
- .get();
-
- List<LanguageSignalInfo> infos = mDatabase.languageInfoDao().getAll();
- assertThat(infos).hasSize(2);
- LanguageSignalInfo infoOne = infos.get(0);
- LanguageSignalInfo infoTwo = infos.get(1);
- assertThat(infoOne).isEqualTo(US_INFO_ONE_FOR_CLASSIFY_TEXT);
- assertThat(infoTwo)
- .isEqualTo(
- new LanguageSignalInfo(
- LOCALE_TAG_CHINA, LanguageSignalInfo.CLASSIFY_TEXT, 1));
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/ulp/LocationSignalProviderTest.java b/java/tests/unittests/src/com/android/textclassifier/ulp/LocationSignalProviderTest.java
deleted file mode 100644
index 0979344..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/ulp/LocationSignalProviderTest.java
+++ /dev/null
@@ -1,75 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.ulp;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import static org.mockito.Mockito.when;
-
-import android.location.Address;
-import android.location.Geocoder;
-import android.location.Location;
-import android.location.LocationManager;
-import android.telephony.TelephonyManager;
-
-import androidx.test.filters.SmallTest;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.Mock;
-import org.mockito.Mockito;
-import org.mockito.MockitoAnnotations;
-
-import java.io.IOException;
-import java.util.Collections;
-import java.util.Locale;
-
-@SmallTest
-public class LocationSignalProviderTest {
- @Mock private LocationManager mLocationManager;
- @Mock private TelephonyManager mTelephonyManager;
- @Mock private LocationSignalProvider mLocationSignalProvider;
- @Mock private Geocoder mGeocoder;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
- mLocationSignalProvider =
- new LocationSignalProvider(mLocationManager, mTelephonyManager, mGeocoder);
- }
-
- @Test
- public void detectLanguageTag_useTelephony() {
- when(mTelephonyManager.getNetworkCountryIso()).thenReturn(Locale.UK.getCountry());
-
- assertThat(mLocationSignalProvider.detectLanguageTag()).isEqualTo("en");
- }
-
- @Test
- public void detectLanguageTag_useLocation() throws IOException {
- when(mTelephonyManager.getNetworkCountryIso()).thenReturn(null);
- Location location = new Location(LocationManager.PASSIVE_PROVIDER);
- when(mLocationManager.getLastKnownLocation(LocationManager.PASSIVE_PROVIDER))
- .thenReturn(location);
- Address address = new Address(Locale.FRANCE);
- address.setCountryCode(Locale.FRANCE.getCountry());
- when(mGeocoder.getFromLocation(Mockito.anyDouble(), Mockito.anyDouble(), Mockito.anyInt()))
- .thenReturn(Collections.singletonList(address));
-
- assertThat(mLocationSignalProvider.detectLanguageTag()).isEqualTo("fr");
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzerTest.java b/java/tests/unittests/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzerTest.java
deleted file mode 100644
index 8e43e99..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzerTest.java
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.ulp;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import static org.mockito.Mockito.when;
-
-import android.content.Context;
-import android.content.SharedPreferences;
-import android.view.textclassifier.TextClassifierEvent;
-
-import androidx.test.core.app.ApplicationProvider;
-import androidx.test.filters.SmallTest;
-
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-import java.util.Arrays;
-import java.util.Locale;
-
-/** Testing {@link ReinforcementLanguageProficiencyAnalyzer} using Mockito. */
-@SmallTest
-public class ReinforcementLanguageProficiencyAnalyzerTest {
-
- private static final String PRIMARY_SYSTEM_LANGUAGE = Locale.CHINESE.toLanguageTag();
- private static final String SECONDARY_SYSTEM_LANGUAGE = Locale.ENGLISH.toLanguageTag();
- private static final String NON_SYSTEM_LANGUAGE = Locale.JAPANESE.toLanguageTag();
- private ReinforcementLanguageProficiencyAnalyzer mProficiencyAnalyzer;
- @Mock private SystemLanguagesProvider mSystemLanguagesProvider;
- private SharedPreferences mSharedPreferences;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
- Context context = ApplicationProvider.getApplicationContext();
- mSharedPreferences = context.getSharedPreferences("test-preferences", Context.MODE_PRIVATE);
- when(mSystemLanguagesProvider.getSystemLanguageTags())
- .thenReturn(Arrays.asList(PRIMARY_SYSTEM_LANGUAGE, SECONDARY_SYSTEM_LANGUAGE));
- mProficiencyAnalyzer =
- new ReinforcementLanguageProficiencyAnalyzer(
- mSystemLanguagesProvider, mSharedPreferences);
- }
-
- @After
- public void teardown() {
- mSharedPreferences.edit().clear().apply();
- }
-
- @Test
- public void canUnderstand_defaultValue() {
- assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(1.0f);
- assertThat(mProficiencyAnalyzer.canUnderstand(SECONDARY_SYSTEM_LANGUAGE)).isEqualTo(1.0f);
- assertThat(mProficiencyAnalyzer.canUnderstand(NON_SYSTEM_LANGUAGE)).isEqualTo(0f);
- }
-
- @Test
- public void canUnderstand_enoughFeedback() {
- sendEvent(TextClassifierEvent.TYPE_ACTIONS_SHOWN, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 50);
- sendEvent(TextClassifierEvent.TYPE_SMART_ACTION, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 40);
-
- assertThat(mProficiencyAnalyzer.canUnderstand(PRIMARY_SYSTEM_LANGUAGE)).isEqualTo(0.8f);
- }
-
- @Test
- public void shouldShowTranslation_defaultValue() {
- assertThat(mProficiencyAnalyzer.shouldShowTranslation(PRIMARY_SYSTEM_LANGUAGE))
- .isEqualTo(true);
- assertThat(mProficiencyAnalyzer.shouldShowTranslation(SECONDARY_SYSTEM_LANGUAGE))
- .isEqualTo(true);
- assertThat(mProficiencyAnalyzer.shouldShowTranslation(NON_SYSTEM_LANGUAGE)).isEqualTo(true);
- }
-
- @Test
- public void shouldShowTranslation_enoughFeedback_true() {
- sendEvent(
- TextClassifierEvent.TYPE_ACTIONS_SHOWN, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 1000);
- sendEvent(TextClassifierEvent.TYPE_SMART_ACTION, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 200);
-
- assertThat(mProficiencyAnalyzer.shouldShowTranslation(PRIMARY_SYSTEM_LANGUAGE))
- .isEqualTo(true);
- }
-
- @Test
- public void shouldShowTranslation_enoughFeedback_false() {
- sendEvent(
- TextClassifierEvent.TYPE_ACTIONS_SHOWN, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 1000);
- sendEvent(
- TextClassifierEvent.TYPE_SMART_ACTION, PRIMARY_SYSTEM_LANGUAGE, /* times= */ 1000);
-
- assertThat(mProficiencyAnalyzer.shouldShowTranslation(PRIMARY_SYSTEM_LANGUAGE))
- .isEqualTo(false);
- }
-
- private void sendEvent(int type, String languageTag, int times) {
- TextClassifierEvent.LanguageDetectionEvent event =
- new TextClassifierEvent.LanguageDetectionEvent.Builder(type)
- .setEntityTypes(languageTag)
- .setActionIndices(0)
- .build();
- for (int i = 0; i < times; i++) {
- mProficiencyAnalyzer.onTextClassifierEvent(event);
- }
- }
-}
diff --git a/java/tests/unittests/src/com/android/textclassifier/ulp/SystemLanguagesProviderTest.java b/java/tests/unittests/src/com/android/textclassifier/ulp/SystemLanguagesProviderTest.java
deleted file mode 100644
index 8117bdd..0000000
--- a/java/tests/unittests/src/com/android/textclassifier/ulp/SystemLanguagesProviderTest.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.ulp;
-
-import static com.google.common.truth.Truth.assertThat;
-
-import android.content.res.Resources;
-import android.os.LocaleList;
-
-import androidx.test.filters.SmallTest;
-
-import org.junit.Before;
-import org.junit.Test;
-
-import java.util.List;
-import java.util.Locale;
-
-@SmallTest
-public class SystemLanguagesProviderTest {
- private SystemLanguagesProvider mSystemLanguagesProvider;
-
- @Before
- public void setup() {
- mSystemLanguagesProvider = new SystemLanguagesProvider();
- }
-
- @Test
- public void getSystemLanguageTags_singleLanguages() {
- Resources.getSystem().getConfiguration().setLocales(new LocaleList(Locale.FRANCE));
-
- List<String> systemLanguageTags = mSystemLanguagesProvider.getSystemLanguageTags();
-
- assertThat(systemLanguageTags).containsExactly("fr");
- }
-
- @Test
- public void getSystemLanguageTags_multipleLanguages() {
- Resources.getSystem()
- .getConfiguration()
- .setLocales(new LocaleList(Locale.FRANCE, Locale.ENGLISH));
-
- List<String> systemLanguageTags = mSystemLanguagesProvider.getSystemLanguageTags();
-
- assertThat(systemLanguageTags).containsExactly("fr", "en");
- }
-}
diff --git a/jni/com/google/android/textclassifier/LangIdModel.java b/jni/com/google/android/textclassifier/LangIdModel.java
index e3f7a79..9701492 100644
--- a/jni/com/google/android/textclassifier/LangIdModel.java
+++ b/jni/com/google/android/textclassifier/LangIdModel.java
@@ -108,6 +108,11 @@
return nativeGetLangIdNoiseThreshold(modelPtr);
}
+ // Visible for testing.
+ int getMinTextSizeInBytes() {
+ return nativeGetMinTextSizeInBytes(modelPtr);
+ }
+
private static native long nativeNew(int fd);
private static native long nativeNewFromPath(String path);
@@ -123,4 +128,6 @@
private native float nativeGetLangIdThreshold(long nativePtr);
private native float nativeGetLangIdNoiseThreshold(long nativePtr);
+
+ private native int nativeGetMinTextSizeInBytes(long nativePtr);
}
diff --git a/native/Android.bp b/native/Android.bp
index 5f6892c..d4ac22d 100644
--- a/native/Android.bp
+++ b/native/Android.bp
@@ -99,6 +99,7 @@
"libtextclassifier_fbgen_lang_id_embedded_network",
"libtextclassifier_fbgen_lang_id_model",
"libtextclassifier_fbgen_actions-entity-data",
+ "libtextclassifier_fbgen_normalization",
],
header_libs: [
@@ -221,6 +222,13 @@
defaults: ["fbgen"],
}
+genrule {
+ name: "libtextclassifier_fbgen_normalization",
+ srcs: ["utils/normalization.fbs"],
+ out: ["utils/normalization_generated.h"],
+ defaults: ["fbgen"],
+}
+
// -----------------
// libtextclassifier
// -----------------
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index e651d19..113d9e9 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -24,6 +24,7 @@
#include "utils/base/logging.h"
#include "utils/flatbuffers.h"
#include "utils/lua-utils.h"
+#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/regex-match.h"
#include "utils/strings/split.h"
@@ -923,7 +924,7 @@
if (model_->tflite_model_spec()->output_actions_scores() >= 0) {
const TensorView<float> actions_scores = model_executor_->OutputView<float>(
model_->tflite_model_spec()->output_actions_scores(), interpreter);
- for (int i = 0; i < model_->action_type()->Length(); i++) {
+ for (int i = 0; i < model_->action_type()->size(); i++) {
const ActionTypeOptions* action_type = model_->action_type()->Get(i);
// Skip disabled action classes, such as the default other category.
if (!action_type->enabled()) {
@@ -1139,7 +1140,18 @@
suggestion.serialized_entity_data.size()));
}
- entity_data->ParseAndSet(mapping->entity_field(), annotation.span.text);
+ UnicodeText normalized_annotation_text =
+ UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false);
+
+ // Apply normalization if specified.
+ if (mapping->normalization_options() != nullptr) {
+ normalized_annotation_text =
+ NormalizeText(unilib_, mapping->normalization_options(),
+ normalized_annotation_text);
+ }
+
+ entity_data->ParseAndSet(mapping->entity_field(),
+ normalized_annotation_text.ToUTF8String());
suggestion.serialized_entity_data = entity_data->Serialize();
}
@@ -1251,11 +1263,22 @@
continue;
}
+ UnicodeText normalized_group_match_text =
+ UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
+
+ // Apply normalization if specified.
+ if (group->normalization_options() != nullptr) {
+ normalized_group_match_text =
+ NormalizeText(unilib_, group->normalization_options(),
+ normalized_group_match_text);
+ }
+
if (group->entity_field() != nullptr) {
TC3_CHECK(entity_data != nullptr);
sets_entity_data = true;
- if (!entity_data->ParseAndSet(group->entity_field(),
- group_match_text.value())) {
+ if (!entity_data->ParseAndSet(
+ group->entity_field(),
+ normalized_group_match_text.ToUTF8String())) {
TC3_LOG(ERROR)
<< "Could not set entity data from rule capturing group.";
return false;
@@ -1275,7 +1298,8 @@
actions->push_back(SuggestionFromSpec(
group->text_reply(),
/*default_type=*/model_->smart_reply_action_type()->str(),
- /*default_response_text=*/group_match_text.value()));
+ /*default_response_text=*/
+ normalized_group_match_text.ToUTF8String()));
}
}
}
@@ -1412,6 +1436,15 @@
const Conversation& conversation, const Annotator* annotator,
const ActionSuggestionOptions& options) const {
ActionsSuggestionsResponse response;
+
+ // Assert that messages are sorted correctly.
+ for (int i = 1; i < conversation.messages.size(); i++) {
+ if (conversation.messages[i].reference_time_ms_utc <
+ conversation.messages[i - 1].reference_time_ms_utc) {
+ TC3_LOG(ERROR) << "Messages are not sorted most recent last.";
+ }
+ }
+
if (!GatherActionsSuggestions(conversation, annotator, options, &response)) {
TC3_LOG(ERROR) << "Could not gather actions suggestions.";
response.actions.clear();
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index 0dc627b..8828ddb 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -157,6 +157,62 @@
"home");
}
+TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotationsNormalization) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ SetTestEntityDataSchema(actions_model.get());
+
+ // Set custom actions from annotations config.
+ actions_model->annotation_actions_spec->annotation_mapping.clear();
+ actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
+ new AnnotationActionsSpec_::AnnotationMappingT);
+ AnnotationActionsSpec_::AnnotationMappingT* mapping =
+ actions_model->annotation_actions_spec->annotation_mapping.back().get();
+ mapping->annotation_collection = "address";
+ mapping->action.reset(new ActionSuggestionSpecT);
+ mapping->action->type = "save_location";
+ mapping->action->score = 1.0;
+ mapping->action->priority_score = 2.0;
+ mapping->entity_field.reset(new FlatbufferFieldPathT);
+ mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
+ mapping->entity_field->field.back()->field_name = "location";
+ mapping->normalization_options.reset(new NormalizationOptionsT);
+ mapping->normalization_options->codepointwise_normalization =
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_);
+
+ AnnotatedSpan annotation;
+ annotation.span = {11, 15};
+ annotation.classification = {ClassificationResult("address", 1.0)};
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "are you at home?",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{annotation},
+ /*locales=*/"en"}}});
+ ASSERT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions.front().type, "save_location");
+ EXPECT_EQ(response.actions.front().score, 1.0);
+
+ // Check that the `location` entity field holds the normalized text of the
+ // annotation.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ response.actions.front().serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
+ "HOME");
+}
+
TEST_F(ActionsSuggestionsTest, SuggestActionsFromDuplicatedAnnotations) {
std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel();
AnnotatedSpan flight_annotation;
@@ -789,6 +845,66 @@
"Kenobi");
}
+TEST_F(ActionsSuggestionsTest, CreateActionsFromRulesWithNormalization) {
+ const std::string actions_model_string =
+ ReadFile(GetModelPath() + kModelFileName);
+ std::unique_ptr<ActionsModelT> actions_model =
+ UnPackActionsModel(actions_model_string.c_str());
+ ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
+
+ actions_model->rules.reset(new RulesModelT());
+ actions_model->rules->rule.emplace_back(new RulesModel_::RuleT);
+ RulesModel_::RuleT* rule = actions_model->rules->rule.back().get();
+ rule->pattern = "^(?i:hello\\sthere)$";
+ rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT);
+ rule->actions.back()->action.reset(new ActionSuggestionSpecT);
+ ActionSuggestionSpecT* action = rule->actions.back()->action.get();
+ action->type = "text_reply";
+ action->response_text = "General Kenobi!";
+ action->score = 1.0f;
+ action->priority_score = 1.0f;
+
+ // Set capturing groups for entity data.
+ rule->actions.back()->capturing_group.emplace_back(
+ new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT);
+ RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
+ rule->actions.back()->capturing_group.back().get();
+ greeting_group->group_id = 0;
+ greeting_group->entity_field.reset(new FlatbufferFieldPathT);
+ greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
+ greeting_group->entity_field->field.back()->field_name = "greeting";
+ greeting_group->normalization_options.reset(new NormalizationOptionsT);
+ greeting_group->normalization_options->codepointwise_normalization =
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
+
+ // Set test entity data schema.
+ SetTestEntityDataSchema(actions_model.get());
+
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, actions_model.get()));
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ ActionsSuggestions::FromUnownedBuffer(
+ reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
+ builder.GetSize(), &unilib_);
+
+ const ActionsSuggestionsResponse& response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_GE(response.actions.size(), 1);
+ EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
+
+ // Check entity data.
+ const flatbuffers::Table* entity =
+ flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
+ response.actions[0].serialized_entity_data.data()));
+ EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
+ "HELLOTHERE");
+}
+
TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
@@ -811,6 +927,9 @@
code_group->text_reply.reset(new ActionSuggestionSpecT);
code_group->text_reply->score = 1.0f;
code_group->text_reply->priority_score = 1.0f;
+ code_group->normalization_options.reset(new NormalizationOptionsT);
+ code_group->normalization_options->codepointwise_normalization =
+ NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE;
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
@@ -828,7 +947,7 @@
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
EXPECT_GE(response.actions.size(), 1);
- EXPECT_EQ(response.actions[0].response_text, "STOP");
+ EXPECT_EQ(response.actions[0].response_text, "stop");
}
TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
@@ -1076,7 +1195,7 @@
TestingMessageEmbedder CreateTestingMessageEmbedder() {
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
- buffer_ = builder.ReleaseBufferPointer();
+ buffer_ = builder.Release();
return TestingMessageEmbedder(
flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
}
diff --git a/native/actions/actions_jni.cc b/native/actions/actions_jni.cc
index 8284921..2cfcbf2 100644
--- a/native/actions/actions_jni.cc
+++ b/native/actions/actions_jni.cc
@@ -19,6 +19,7 @@
#include "actions/actions_jni.h"
#include <jni.h>
+
#include <map>
#include <type_traits>
#include <vector>
@@ -27,10 +28,12 @@
#include "annotator/annotator.h"
#include "annotator/annotator_jni_common.h"
#include "utils/base/integral_types.h"
+#include "utils/base/statusor.h"
#include "utils/intents/intent-generator.h"
#include "utils/intents/jni.h"
+#include "utils/java/jni-base.h"
#include "utils/java/jni-cache.h"
-#include "utils/java/scoped_local_ref.h"
+#include "utils/java/jni-helper.h"
#include "utils/java/string_utils.h"
#include "utils/memory/mmap.h"
@@ -42,6 +45,7 @@
using libtextclassifier3::Conversation;
using libtextclassifier3::IntentGenerator;
using libtextclassifier3::ScopedLocalRef;
+using libtextclassifier3::StatusOr;
using libtextclassifier3::ToStlString;
// When using the Java's ICU, UniLib needs to be instantiated with a JavaVM
@@ -116,52 +120,59 @@
return options;
}
-jobjectArray ActionSuggestionsToJObjectArray(
+StatusOr<ScopedLocalRef<jobjectArray>> ActionSuggestionsToJObjectArray(
JNIEnv* env, const ActionsSuggestionsJniContext* context,
jobject app_context,
const reflection::Schema* annotations_entity_data_schema,
const std::vector<ActionSuggestion>& action_result,
const Conversation& conversation, const jstring device_locales,
const bool generate_intents) {
- const ScopedLocalRef<jclass> result_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
- "$ActionSuggestion"),
- env);
- if (!result_class) {
+ auto status_or_result_class = JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ActionSuggestion");
+ if (!status_or_result_class.ok()) {
TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
- return nullptr;
+ return status_or_result_class.status();
}
+ ScopedLocalRef<jclass> result_class =
+ std::move(status_or_result_class.ValueOrDie());
- const jmethodID result_class_constructor = env->GetMethodID(
- result_class.get(), "<init>",
- "(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
- TC3_NAMED_VARIANT_CLASS_NAME_STR
- ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";)V");
- const jobjectArray results =
- env->NewObjectArray(action_result.size(), result_class.get(), nullptr);
+ TC3_ASSIGN_OR_RETURN(
+ const jmethodID result_class_constructor,
+ JniHelper::GetMethodID(
+ env, result_class.get(), "<init>",
+ "(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
+ TC3_NAMED_VARIANT_CLASS_NAME_STR
+ ";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR
+ ";)V"));
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, action_result.size(),
+ result_class.get(), nullptr));
for (int i = 0; i < action_result.size(); i++) {
- jobject extras = nullptr;
-
+ ScopedLocalRef<jobjectArray> extras;
const reflection::Schema* actions_entity_data_schema =
context->model()->entity_data_schema();
if (actions_entity_data_schema != nullptr &&
!action_result[i].serialized_entity_data.empty()) {
- extras = context->template_handler()->EntityDataAsNamedVariantArray(
- actions_entity_data_schema, action_result[i].serialized_entity_data);
+ TC3_ASSIGN_OR_RETURN(
+ extras, context->template_handler()->EntityDataAsNamedVariantArray(
+ actions_entity_data_schema,
+ action_result[i].serialized_entity_data));
}
- jbyteArray serialized_entity_data = nullptr;
+ ScopedLocalRef<jbyteArray> serialized_entity_data;
if (!action_result[i].serialized_entity_data.empty()) {
- serialized_entity_data =
- env->NewByteArray(action_result[i].serialized_entity_data.size());
+ TC3_ASSIGN_OR_RETURN(
+ serialized_entity_data,
+ JniHelper::NewByteArray(
+ env, action_result[i].serialized_entity_data.size()));
env->SetByteArrayRegion(
- serialized_entity_data, 0,
+ serialized_entity_data.get(), 0,
action_result[i].serialized_entity_data.size(),
reinterpret_cast<const jbyte*>(
action_result[i].serialized_entity_data.data()));
}
- jobject remote_action_templates_result = nullptr;
+ ScopedLocalRef<jobjectArray> remote_action_templates_result;
if (generate_intents) {
std::vector<RemoteActionTemplate> remote_action_templates;
if (context->intent_generator()->GenerateIntents(
@@ -169,96 +180,127 @@
/*annotations_entity_data_schema=*/nullptr,
/*actions_entity_data_schema=*/nullptr,
&remote_action_templates)) {
- remote_action_templates_result =
+ TC3_ASSIGN_OR_RETURN(
+ remote_action_templates_result,
context->template_handler()->RemoteActionTemplatesToJObjectArray(
- remote_action_templates);
+ remote_action_templates));
}
}
- ScopedLocalRef<jstring> reply = context->jni_cache()->ConvertToJavaString(
- action_result[i].response_text);
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> reply,
+ context->jni_cache()->ConvertToJavaString(
+ action_result[i].response_text));
- ScopedLocalRef<jobject> result(env->NewObject(
- result_class.get(), result_class_constructor, reply.get(),
- env->NewStringUTF(action_result[i].type.c_str()),
- static_cast<jfloat>(action_result[i].score), extras,
- serialized_entity_data, remote_action_templates_result));
- env->SetObjectArrayElement(results, i, result.get());
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> action_type,
+ JniHelper::NewStringUTF(env, action_result[i].type.c_str()));
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(env, result_class.get(), result_class_constructor,
+ reply.get(), action_type.get(),
+ static_cast<jfloat>(action_result[i].score),
+ extras.get(), serialized_entity_data.get(),
+ remote_action_templates_result.get()));
+ env->SetObjectArrayElement(results.get(), i, result.get());
}
return results;
}
-ConversationMessage FromJavaConversationMessage(JNIEnv* env, jobject jmessage) {
+StatusOr<ConversationMessage> FromJavaConversationMessage(JNIEnv* env,
+ jobject jmessage) {
if (!jmessage) {
return {};
}
- const ScopedLocalRef<jclass> message_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
- "$ConversationMessage"),
- env);
- const std::pair<bool, jobject> status_or_text = CallJniMethod0<jobject>(
- env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod, "getText",
- "Ljava/lang/String;");
- const std::pair<bool, int32> status_or_user_id =
- CallJniMethod0<int32>(env, jmessage, message_class.get(),
- &JNIEnv::CallIntMethod, "getUserId", "I");
- const std::pair<bool, int64> status_or_reference_time = CallJniMethod0<int64>(
- env, jmessage, message_class.get(), &JNIEnv::CallLongMethod,
- "getReferenceTimeMsUtc", "J");
- const std::pair<bool, jobject> status_or_reference_timezone =
- CallJniMethod0<jobject>(env, jmessage, message_class.get(),
- &JNIEnv::CallObjectMethod, "getReferenceTimezone",
- "Ljava/lang/String;");
- const std::pair<bool, jobject> status_or_detected_text_language_tags =
- CallJniMethod0<jobject>(
- env, jmessage, message_class.get(), &JNIEnv::CallObjectMethod,
- "getDetectedTextLanguageTags", "Ljava/lang/String;");
- if (!status_or_text.first || !status_or_user_id.first ||
- !status_or_detected_text_language_tags.first ||
- !status_or_reference_time.first || !status_or_reference_timezone.first) {
- return {};
- }
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> message_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+ "$ConversationMessage"));
+ // .getText()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_text_method,
+ JniHelper::GetMethodID(env, message_class.get(), "getText",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> text,
+ JniHelper::CallObjectMethod<jstring>(env, jmessage, get_text_method));
+
+ // .getUserId()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_user_id_method,
+ JniHelper::GetMethodID(env, message_class.get(), "getUserId", "()I"));
+ TC3_ASSIGN_OR_RETURN(int32 user_id, JniHelper::CallIntMethod(
+ env, jmessage, get_user_id_method));
+
+ // .getReferenceTimeMsUtc()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_reference_time_method,
+ JniHelper::GetMethodID(env, message_class.get(),
+ "getReferenceTimeMsUtc", "()J"));
+ TC3_ASSIGN_OR_RETURN(
+ int64 reference_time,
+ JniHelper::CallLongMethod(env, jmessage, get_reference_time_method));
+
+ // .getReferenceTimezone()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_reference_timezone_method,
+ JniHelper::GetMethodID(env, message_class.get(), "getReferenceTimezone",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> reference_timezone,
+ JniHelper::CallObjectMethod<jstring>(
+ env, jmessage, get_reference_timezone_method));
+
+ // .getDetectedTextLanguageTags()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_detected_text_language_tags_method,
+ JniHelper::GetMethodID(env, message_class.get(),
+ "getDetectedTextLanguageTags",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> detected_text_language_tags,
+ JniHelper::CallObjectMethod<jstring>(
+ env, jmessage, get_detected_text_language_tags_method));
ConversationMessage message;
- message.text = ToStlString(env, static_cast<jstring>(status_or_text.second));
- message.user_id = status_or_user_id.second;
- message.reference_time_ms_utc = status_or_reference_time.second;
- message.reference_timezone = ToStlString(
- env, static_cast<jstring>(status_or_reference_timezone.second));
- message.detected_text_language_tags = ToStlString(
- env, static_cast<jstring>(status_or_detected_text_language_tags.second));
+ TC3_ASSIGN_OR_RETURN(message.text, ToStlString(env, text.get()));
+ message.user_id = user_id;
+ message.reference_time_ms_utc = reference_time;
+ TC3_ASSIGN_OR_RETURN(message.reference_timezone,
+ ToStlString(env, reference_timezone.get()));
+ TC3_ASSIGN_OR_RETURN(message.detected_text_language_tags,
+ ToStlString(env, detected_text_language_tags.get()));
return message;
}
-Conversation FromJavaConversation(JNIEnv* env, jobject jconversation) {
+StatusOr<Conversation> FromJavaConversation(JNIEnv* env,
+ jobject jconversation) {
if (!jconversation) {
- return {};
+ return {Status::UNKNOWN};
}
- const ScopedLocalRef<jclass> conversation_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
- "$Conversation"),
- env);
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> conversation_class,
+ JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$Conversation"));
- const std::pair<bool, jobject> status_or_messages = CallJniMethod0<jobject>(
- env, jconversation, conversation_class.get(), &JNIEnv::CallObjectMethod,
- "getConversationMessages",
- "[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ConversationMessage;");
-
- if (!status_or_messages.first) {
- return {};
- }
-
- const jobjectArray jmessages =
- reinterpret_cast<jobjectArray>(status_or_messages.second);
-
- const int size = env->GetArrayLength(jmessages);
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_conversation_messages_method,
+ JniHelper::GetMethodID(env, conversation_class.get(),
+ "getConversationMessages",
+ "()[L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+ "$ConversationMessage;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> jmessages,
+ JniHelper::CallObjectMethod<jobjectArray>(
+ env, jconversation, get_conversation_messages_method));
std::vector<ConversationMessage> messages;
+ const int size = env->GetArrayLength(jmessages.get());
for (int i = 0; i < size; i++) {
- jobject jmessage = env->GetObjectArrayElement(jmessages, i);
- ConversationMessage message = FromJavaConversationMessage(env, jmessage);
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> jmessage,
+ JniHelper::GetObjectArrayElement<jobject>(env, jmessages.get(), i));
+ TC3_ASSIGN_OR_RETURN(ConversationMessage message,
+ FromJavaConversationMessage(env, jmessage.get()));
messages.push_back(message);
}
Conversation conversation;
@@ -266,16 +308,17 @@
return conversation;
}
-jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+StatusOr<ScopedLocalRef<jstring>> GetLocalesFromMmap(
+ JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
if (!mmap->handle().ok()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
const ActionsModel* model = libtextclassifier3::ViewActionsModel(
mmap->handle().start(), mmap->handle().num_bytes());
if (!model || !model->locales()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
- return env->NewStringUTF(model->locales()->c_str());
+ return JniHelper::NewStringUTF(env, model->locales()->c_str());
}
jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
@@ -290,16 +333,17 @@
return model->version();
}
-jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+StatusOr<ScopedLocalRef<jstring>> GetNameFromMmap(
+ JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
if (!mmap->handle().ok()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
const ActionsModel* model = libtextclassifier3::ViewActionsModel(
mmap->handle().start(), mmap->handle().num_bytes());
if (!model || !model->name()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
- return env->NewStringUTF(model->name()->c_str());
+ return JniHelper::NewStringUTF(env, model->name()->c_str());
}
} // namespace
} // namespace libtextclassifier3
@@ -336,7 +380,7 @@
(JNIEnv* env, jobject thiz, jstring path, jbyteArray serialized_preconditions) {
std::shared_ptr<libtextclassifier3::JniCache> jni_cache =
libtextclassifier3::JniCache::Create(env);
- const std::string path_str = ToStlString(env, path);
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
std::string preconditions;
if (serialized_preconditions != nullptr &&
!libtextclassifier3::JByteArrayToString(env, serialized_preconditions,
@@ -388,7 +432,8 @@
if (!ptr) {
return nullptr;
}
- const Conversation conversation = FromJavaConversation(env, jconversation);
+ TC3_ASSIGN_OR_RETURN_NULL(const Conversation conversation,
+ FromJavaConversation(env, jconversation));
const ActionSuggestionOptions options =
FromJavaActionSuggestionOptions(env, joptions);
const ActionsSuggestionsJniContext* context =
@@ -400,9 +445,13 @@
const reflection::Schema* anntotations_entity_data_schema =
annotator ? annotator->entity_data_schema() : nullptr;
- return ActionSuggestionsToJObjectArray(
- env, context, app_context, anntotations_entity_data_schema,
- response.actions, conversation, device_locales, generate_intents);
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> result,
+ ActionSuggestionsToJObjectArray(
+ env, context, app_context, anntotations_entity_data_schema,
+ response.actions, conversation, device_locales, generate_intents));
+ return result.release();
}
TC3_JNI_METHOD(void, TC3_ACTIONS_CLASS_NAME, nativeCloseActionsModel)
@@ -416,28 +465,40 @@
(JNIEnv* env, jobject clazz, jint fd) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd));
- return libtextclassifier3::GetLocalesFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jstring> result,
+ libtextclassifier3::GetLocalesFromMmap(env, mmap.get()));
+ return result.release();
}
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetLocalesWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
- return libtextclassifier3::GetLocalesFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jstring> result,
+ libtextclassifier3::GetLocalesFromMmap(env, mmap.get()));
+ return result.release();
}
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetName)
(JNIEnv* env, jobject clazz, jint fd) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd));
- return libtextclassifier3::GetNameFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jstring> result,
+ libtextclassifier3::GetNameFromMmap(env, mmap.get()));
+ return result.release();
}
TC3_JNI_METHOD(jstring, TC3_ACTIONS_CLASS_NAME, nativeGetNameWithOffset)
(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
- return libtextclassifier3::GetNameFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jstring> result,
+ libtextclassifier3::GetNameFromMmap(env, mmap.get()));
+ return result.release();
}
TC3_JNI_METHOD(jint, TC3_ACTIONS_CLASS_NAME, nativeGetVersion)
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index d939d7c..fa76c43 100755
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -18,6 +18,7 @@
include "utils/codepoint-range.fbs";
include "utils/flatbuffers.fbs";
include "utils/intents/intent-config.fbs";
+include "utils/normalization.fbs";
include "utils/resources.fbs";
include "utils/tokenizer.fbs";
include "utils/zlib/buffer.fbs";
@@ -304,6 +305,9 @@
// If set, the text of the annotation will be used to set a field in the
// action entity data.
entity_field:FlatbufferFieldPath;
+
+ // If set, normalization to apply to the annotation text.
+ normalization_options:NormalizationOptions;
}
// Configuration for actions based on annotatations.
@@ -383,6 +387,9 @@
// If set, the capturing group text will be used to create a text
// reply.
text_reply:ActionSuggestionSpec;
+
+ // If set, normalization to apply to the capturing group text.
+ normalization_options:NormalizationOptions;
}
// The actions to produce upon triggering.
diff --git a/native/actions/flatbuffer-utils.cc b/native/actions/flatbuffer-utils.cc
new file mode 100644
index 0000000..b91fb40
--- /dev/null
+++ b/native/actions/flatbuffer-utils.cc
@@ -0,0 +1,88 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "actions/flatbuffer-utils.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/flatbuffers.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3 {
+
+bool SwapFieldNamesForOffsetsInPathInActionsModel(ActionsModelT* model) {
+ if (model->actions_entity_data_schema.empty()) {
+ // Nothing to do.
+ return true;
+ }
+
+ const reflection::Schema* schema =
+ LoadAndVerifyFlatbuffer<reflection::Schema>(
+ model->actions_entity_data_schema.data(),
+ model->actions_entity_data_schema.size());
+
+ // Resolve offsets in regex rules.
+ if (model->rules != nullptr) {
+ for (std::unique_ptr<RulesModel_::RuleT>& rule : model->rules->rule) {
+ for (std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT>& rule_action :
+ rule->actions) {
+ for (std::unique_ptr<
+ RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT>&
+ capturing_group : rule_action->capturing_group) {
+ if (capturing_group->entity_field == nullptr) {
+ continue;
+ }
+ if (!SwapFieldNamesForOffsetsInPath(
+ schema, capturing_group->entity_field.get())) {
+ return false;
+ }
+ }
+ }
+ }
+ }
+
+ // Resolve offsets in annotation action mapping.
+ if (model->annotation_actions_spec != nullptr) {
+ for (std::unique_ptr<AnnotationActionsSpec_::AnnotationMappingT>& mapping :
+ model->annotation_actions_spec->annotation_mapping) {
+ if (mapping->entity_field == nullptr) {
+ continue;
+ }
+ if (!SwapFieldNamesForOffsetsInPath(schema,
+ mapping->entity_field.get())) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+std::string SwapFieldNamesForOffsetsInPathInSerializedActionsModel(
+ const std::string& model) {
+ std::unique_ptr<ActionsModelT> unpacked_model =
+ UnPackActionsModel(model.c_str());
+ TC3_CHECK(unpacked_model != nullptr);
+ TC3_CHECK(SwapFieldNamesForOffsetsInPathInActionsModel(unpacked_model.get()));
+ flatbuffers::FlatBufferBuilder builder;
+ FinishActionsModelBuffer(builder,
+ ActionsModel::Pack(builder, unpacked_model.get()));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+} // namespace libtextclassifier3
diff --git a/native/actions/flatbuffer-utils.h b/native/actions/flatbuffer-utils.h
new file mode 100644
index 0000000..2479599
--- /dev/null
+++ b/native/actions/flatbuffer-utils.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for working with FlatBuffers in the actions model.
+
+#ifndef LIBTEXTCLASSIFIER_ACTIONS_FLATBUFFER_UTILS_H_
+#define LIBTEXTCLASSIFIER_ACTIONS_FLATBUFFER_UTILS_H_
+
+#include <string>
+
+#include "actions/actions_model_generated.h"
+
+namespace libtextclassifier3 {
+
+// Resolves field lookups by name to the concrete field offsets in the regex
+// rules of the model.
+bool SwapFieldNamesForOffsetsInPathInActionsModel(ActionsModelT* model);
+
+// Same as above but for a serialized model.
+std::string SwapFieldNamesForOffsetsInPathInSerializedActionsModel(
+ const std::string& model);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ACTIONS_FLATBUFFER_UTILS_H_
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 867eea0..d910c8d 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -29,6 +29,7 @@
#include "utils/base/logging.h"
#include "utils/checksum.h"
#include "utils/math/softmax.h"
+#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/regex-match.h"
#include "utils/utf8/unicodetext.h"
@@ -416,7 +417,7 @@
model_->duration_annotator_options()->enabled()) {
duration_annotator_.reset(
new DurationAnnotator(model_->duration_annotator_options(),
- selection_feature_processor_.get()));
+ selection_feature_processor_.get(), unilib_));
}
if (model_->entity_data_schema()) {
@@ -505,6 +506,10 @@
TC3_LOG(ERROR) << "Failed to initialize the knowledge engine.";
return false;
}
+ if (model_->triggering_options() != nullptr) {
+ knowledge_engine->SetPriorityScore(
+ model_->triggering_options()->knowledge_priority_score());
+ }
knowledge_engine_ = std::move(knowledge_engine);
return true;
}
@@ -2075,8 +2080,19 @@
// Set entity field from capturing group text.
if (group->entity_field_path() != nullptr) {
- if (!entity_data->ParseAndSet(group->entity_field_path(),
- group_match_text.value())) {
+ UnicodeText normalized_group_match_text =
+ UTF8ToUnicodeText(group_match_text.value(), /*do_copy=*/false);
+
+ // Apply normalization if specified.
+ if (group->normalization_options() != nullptr) {
+ normalized_group_match_text =
+ NormalizeText(unilib_, group->normalization_options(),
+ normalized_group_match_text);
+ }
+
+ if (!entity_data->ParseAndSet(
+ group->entity_field_path(),
+ normalized_group_match_text.ToUTF8String())) {
TC3_LOG(ERROR)
<< "Could not set entity data from rule capturing group.";
return false;
diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc
index e5b7833..28be366 100644
--- a/native/annotator/annotator_jni.cc
+++ b/native/annotator/annotator_jni.cc
@@ -19,6 +19,7 @@
#include "annotator/annotator_jni.h"
#include <jni.h>
+
#include <type_traits>
#include <vector>
@@ -26,11 +27,12 @@
#include "annotator/annotator_jni_common.h"
#include "annotator/types.h"
#include "utils/base/integral_types.h"
+#include "utils/base/statusor.h"
#include "utils/calendar/calendar.h"
#include "utils/intents/intent-generator.h"
#include "utils/intents/jni.h"
#include "utils/java/jni-cache.h"
-#include "utils/java/scoped_local_ref.h"
+#include "utils/java/jni-helper.h"
#include "utils/java/string_utils.h"
#include "utils/memory/mmap.h"
#include "utils/strings/stringpiece.h"
@@ -48,8 +50,10 @@
using libtextclassifier3::Annotator;
using libtextclassifier3::ClassificationResult;
using libtextclassifier3::CodepointSpan;
+using libtextclassifier3::JniHelper;
using libtextclassifier3::Model;
using libtextclassifier3::ScopedLocalRef;
+using libtextclassifier3::StatusOr;
// When using the Java's ICU, CalendarLib and UniLib need to be instantiated
// with a JavaVM pointer from JNI. When using a standard ICU the pointer is
// not needed and the objects are instantiated implicitly.
@@ -71,6 +75,7 @@
if (jni_cache == nullptr || model == nullptr) {
return nullptr;
}
+ // Intent generator will be null if the options are not specified.
std::unique_ptr<IntentGenerator> intent_generator =
IntentGenerator::Create(model->model()->intent_options(),
model->model()->resources(), jni_cache);
@@ -79,6 +84,7 @@
if (template_handler == nullptr) {
return nullptr;
}
+
return new AnnotatorJniContext(jni_cache, std::move(model),
std::move(intent_generator),
std::move(template_handler));
@@ -90,6 +96,8 @@
Annotator* model() const { return model_.get(); }
+ // NOTE: Intent generator will be null if the options are not specified in
+ // the model.
IntentGenerator* intent_generator() const { return intent_generator_.get(); }
RemoteActionTemplatesHandler* template_handler() const {
@@ -113,184 +121,217 @@
std::unique_ptr<RemoteActionTemplatesHandler> template_handler_;
};
-jobject ClassificationResultWithIntentsToJObject(
+StatusOr<ScopedLocalRef<jobject>> ClassificationResultWithIntentsToJObject(
JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
jclass result_class, jmethodID result_class_constructor,
jclass datetime_parse_class, jmethodID datetime_parse_class_constructor,
const jstring device_locales, const ClassificationOptions* options,
const std::string& context, const CodepointSpan& selection_indices,
const ClassificationResult& classification_result, bool generate_intents) {
- jstring row_string =
- env->NewStringUTF(classification_result.collection.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> row_string,
+ JniHelper::NewStringUTF(env, classification_result.collection.c_str()));
- jobject row_datetime_parse = nullptr;
+ ScopedLocalRef<jobject> row_datetime_parse;
if (classification_result.datetime_parse_result.IsSet()) {
- row_datetime_parse =
- env->NewObject(datetime_parse_class, datetime_parse_class_constructor,
- classification_result.datetime_parse_result.time_ms_utc,
- classification_result.datetime_parse_result.granularity);
+ TC3_ASSIGN_OR_RETURN(
+ row_datetime_parse,
+ JniHelper::NewObject(
+ env, datetime_parse_class, datetime_parse_class_constructor,
+ classification_result.datetime_parse_result.time_ms_utc,
+ classification_result.datetime_parse_result.granularity));
}
- jbyteArray serialized_knowledge_result = nullptr;
+ ScopedLocalRef<jbyteArray> serialized_knowledge_result;
const std::string& serialized_knowledge_result_string =
classification_result.serialized_knowledge_result;
if (!serialized_knowledge_result_string.empty()) {
- serialized_knowledge_result =
- env->NewByteArray(serialized_knowledge_result_string.size());
- env->SetByteArrayRegion(serialized_knowledge_result, 0,
+ TC3_ASSIGN_OR_RETURN(serialized_knowledge_result,
+ JniHelper::NewByteArray(
+ env, serialized_knowledge_result_string.size()));
+ env->SetByteArrayRegion(serialized_knowledge_result.get(), 0,
serialized_knowledge_result_string.size(),
reinterpret_cast<const jbyte*>(
serialized_knowledge_result_string.data()));
}
- jstring contact_name = nullptr;
+ ScopedLocalRef<jstring> contact_name;
if (!classification_result.contact_name.empty()) {
- contact_name =
- env->NewStringUTF(classification_result.contact_name.c_str());
+ TC3_ASSIGN_OR_RETURN(contact_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_name.c_str()));
}
- jstring contact_given_name = nullptr;
+ ScopedLocalRef<jstring> contact_given_name;
if (!classification_result.contact_given_name.empty()) {
- contact_given_name =
- env->NewStringUTF(classification_result.contact_given_name.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_given_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_given_name.c_str()));
}
- jstring contact_family_name = nullptr;
+ ScopedLocalRef<jstring> contact_family_name;
if (!classification_result.contact_family_name.empty()) {
- contact_family_name =
- env->NewStringUTF(classification_result.contact_family_name.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_family_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_family_name.c_str()));
}
- jstring contact_nickname = nullptr;
+ ScopedLocalRef<jstring> contact_nickname;
if (!classification_result.contact_nickname.empty()) {
- contact_nickname =
- env->NewStringUTF(classification_result.contact_nickname.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_nickname,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_nickname.c_str()));
}
- jstring contact_email_address = nullptr;
+ ScopedLocalRef<jstring> contact_email_address;
if (!classification_result.contact_email_address.empty()) {
- contact_email_address =
- env->NewStringUTF(classification_result.contact_email_address.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_email_address,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_email_address.c_str()));
}
- jstring contact_phone_number = nullptr;
+ ScopedLocalRef<jstring> contact_phone_number;
if (!classification_result.contact_phone_number.empty()) {
- contact_phone_number =
- env->NewStringUTF(classification_result.contact_phone_number.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_phone_number,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_phone_number.c_str()));
}
- jstring contact_id = nullptr;
+ ScopedLocalRef<jstring> contact_id;
if (!classification_result.contact_id.empty()) {
- contact_id = env->NewStringUTF(classification_result.contact_id.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ contact_id,
+ JniHelper::NewStringUTF(env, classification_result.contact_id.c_str()));
}
- jstring app_name = nullptr;
+ ScopedLocalRef<jstring> app_name;
if (!classification_result.app_name.empty()) {
- app_name = env->NewStringUTF(classification_result.app_name.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ app_name,
+ JniHelper::NewStringUTF(env, classification_result.app_name.c_str()));
}
- jstring app_package_name = nullptr;
+ ScopedLocalRef<jstring> app_package_name;
if (!classification_result.app_package_name.empty()) {
- app_package_name =
- env->NewStringUTF(classification_result.app_package_name.c_str());
+ TC3_ASSIGN_OR_RETURN(
+ app_package_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.app_package_name.c_str()));
}
- jobject extras = nullptr;
+ ScopedLocalRef<jobjectArray> extras;
if (model_context->model()->entity_data_schema() != nullptr &&
!classification_result.serialized_entity_data.empty()) {
- extras = model_context->template_handler()->EntityDataAsNamedVariantArray(
- model_context->model()->entity_data_schema(),
- classification_result.serialized_entity_data);
+ TC3_ASSIGN_OR_RETURN(
+ extras,
+ model_context->template_handler()->EntityDataAsNamedVariantArray(
+ model_context->model()->entity_data_schema(),
+ classification_result.serialized_entity_data));
}
- jbyteArray serialized_entity_data = nullptr;
+ ScopedLocalRef<jbyteArray> serialized_entity_data;
if (!classification_result.serialized_entity_data.empty()) {
- serialized_entity_data =
- env->NewByteArray(classification_result.serialized_entity_data.size());
+ TC3_ASSIGN_OR_RETURN(
+ serialized_entity_data,
+ JniHelper::NewByteArray(
+ env, classification_result.serialized_entity_data.size()));
env->SetByteArrayRegion(
- serialized_entity_data, 0,
+ serialized_entity_data.get(), 0,
classification_result.serialized_entity_data.size(),
reinterpret_cast<const jbyte*>(
classification_result.serialized_entity_data.data()));
}
- jobject remote_action_templates_result = nullptr;
+ ScopedLocalRef<jobjectArray> remote_action_templates_result;
// Only generate RemoteActionTemplate for the top classification result
// as classifyText does not need RemoteAction from other results anyway.
if (generate_intents && model_context->intent_generator() != nullptr) {
std::vector<RemoteActionTemplate> remote_action_templates;
- if (model_context->intent_generator()->GenerateIntents(
+ if (!model_context->intent_generator()->GenerateIntents(
device_locales, classification_result,
options->reference_time_ms_utc, context, selection_indices,
app_context, model_context->model()->entity_data_schema(),
&remote_action_templates)) {
- remote_action_templates_result =
- model_context->template_handler()
- ->RemoteActionTemplatesToJObjectArray(remote_action_templates);
+ return {Status::UNKNOWN};
}
+
+ TC3_ASSIGN_OR_RETURN(
+ remote_action_templates_result,
+ model_context->template_handler()->RemoteActionTemplatesToJObjectArray(
+ remote_action_templates));
}
- return env->NewObject(
- result_class, result_class_constructor, row_string,
- static_cast<jfloat>(classification_result.score), row_datetime_parse,
- serialized_knowledge_result, contact_name, contact_given_name,
- contact_family_name, contact_nickname, contact_email_address,
- contact_phone_number, contact_id, app_name, app_package_name, extras,
- serialized_entity_data, remote_action_templates_result,
- classification_result.duration_ms, classification_result.numeric_value,
+ return JniHelper::NewObject(
+ env, result_class, result_class_constructor, row_string.get(),
+ static_cast<jfloat>(classification_result.score),
+ row_datetime_parse.get(), serialized_knowledge_result.get(),
+ contact_name.get(), contact_given_name.get(), contact_family_name.get(),
+ contact_nickname.get(), contact_email_address.get(),
+ contact_phone_number.get(), contact_id.get(), app_name.get(),
+ app_package_name.get(), extras.get(), serialized_entity_data.get(),
+ remote_action_templates_result.get(), classification_result.duration_ms,
+ classification_result.numeric_value,
classification_result.numeric_double_value);
}
-jobjectArray ClassificationResultsWithIntentsToJObjectArray(
+StatusOr<ScopedLocalRef<jobjectArray>>
+ClassificationResultsWithIntentsToJObjectArray(
JNIEnv* env, const AnnotatorJniContext* model_context, jobject app_context,
const jstring device_locales, const ClassificationOptions* options,
const std::string& context, const CodepointSpan& selection_indices,
const std::vector<ClassificationResult>& classification_result,
bool generate_intents) {
- const ScopedLocalRef<jclass> result_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$ClassificationResult"),
- env);
- if (!result_class) {
- TC3_LOG(ERROR) << "Couldn't find ClassificationResult class.";
- return nullptr;
- }
- const ScopedLocalRef<jclass> datetime_parse_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$DatetimeResult"),
- env);
- if (!datetime_parse_class) {
- TC3_LOG(ERROR) << "Couldn't find DatetimeResult class.";
- return nullptr;
- }
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> result_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$ClassificationResult"));
- const jmethodID result_class_constructor = env->GetMethodID(
- result_class.get(), "<init>",
- "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/String;"
- "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;"
- "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
- "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH
- "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";JJD)V");
- const jmethodID datetime_parse_class_constructor =
- env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V");
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> datetime_parse_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$DatetimeResult"));
- const jobjectArray results = env->NewObjectArray(classification_result.size(),
- result_class.get(), nullptr);
+ TC3_ASSIGN_OR_RETURN(
+ const jmethodID result_class_constructor,
+ JniHelper::GetMethodID(
+ env, result_class.get(), "<init>",
+ "(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/"
+ "String;"
+ "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
+ "String;"
+ "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
+ "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH
+ "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";JJD)V"));
+ TC3_ASSIGN_OR_RETURN(const jmethodID datetime_parse_class_constructor,
+ JniHelper::GetMethodID(env, datetime_parse_class.get(),
+ "<init>", "(JI)V"));
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, classification_result.size(),
+ result_class.get()));
+
for (int i = 0; i < classification_result.size(); i++) {
- jobject result = ClassificationResultWithIntentsToJObject(
- env, model_context, app_context, result_class.get(),
- result_class_constructor, datetime_parse_class.get(),
- datetime_parse_class_constructor, device_locales, options, context,
- selection_indices, classification_result[i],
- generate_intents && (i == 0));
- env->SetObjectArrayElement(results, i, result);
- env->DeleteLocalRef(result);
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> result,
+ ClassificationResultWithIntentsToJObject(
+ env, model_context, app_context, result_class.get(),
+ result_class_constructor, datetime_parse_class.get(),
+ datetime_parse_class_constructor, device_locales, options, context,
+ selection_indices, classification_result[i],
+ generate_intents && (i == 0)));
+ env->SetObjectArrayElement(results.get(), i, result.get());
}
return results;
}
-jobjectArray ClassificationResultsToJObjectArray(
+StatusOr<ScopedLocalRef<jobjectArray>> ClassificationResultsToJObjectArray(
JNIEnv* env, const AnnotatorJniContext* model_context,
const std::vector<ClassificationResult>& classification_result) {
return ClassificationResultsWithIntentsToJObjectArray(
@@ -361,16 +402,18 @@
return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true);
}
-jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+StatusOr<ScopedLocalRef<jstring>> GetLocalesFromMmap(
+ JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
if (!mmap->handle().ok()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
const Model* model = libtextclassifier3::ViewModel(
mmap->handle().start(), mmap->handle().num_bytes());
if (!model || !model->locales()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
- return env->NewStringUTF(model->locales()->c_str());
+
+ return JniHelper::NewStringUTF(env, model->locales()->c_str());
}
jint GetVersionFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
@@ -385,16 +428,17 @@
return model->version();
}
-jstring GetNameFromMmap(JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
+StatusOr<ScopedLocalRef<jstring>> GetNameFromMmap(
+ JNIEnv* env, libtextclassifier3::ScopedMmap* mmap) {
if (!mmap->handle().ok()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
const Model* model = libtextclassifier3::ViewModel(
mmap->handle().start(), mmap->handle().num_bytes());
if (!model || !model->name()) {
- return env->NewStringUTF("");
+ return JniHelper::NewStringUTF(env, "");
}
- return env->NewStringUTF(model->name()->c_str());
+ return JniHelper::NewStringUTF(env, model->name()->c_str());
}
} // namespace libtextclassifier3
@@ -427,7 +471,7 @@
TC3_JNI_METHOD(jlong, TC3_ANNOTATOR_CLASS_NAME, nativeNewAnnotatorFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
- const std::string path_str = ToStlString(env, path);
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
std::shared_ptr<libtextclassifier3::JniCache> jni_cache(
libtextclassifier3::JniCache::Create(env));
#ifdef TC3_USE_JAVAICU
@@ -531,17 +575,22 @@
return nullptr;
}
const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- const std::string context_utf8 = ToStlString(env, context);
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
+ ToStlString(env, context));
CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
- CodepointSpan selection = model->SuggestSelection(
- context_utf8, input_indices, FromJavaSelectionOptions(env, options));
+ TC3_ASSIGN_OR_RETURN_NULL(
+ libtextclassifier3::SelectionOptions selection_options,
+ FromJavaSelectionOptions(env, options));
+ CodepointSpan selection =
+ model->SuggestSelection(context_utf8, input_indices, selection_options);
selection = ConvertIndicesUTF8ToBMP(context_utf8, selection);
- jintArray result = env->NewIntArray(2);
- env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection)));
- env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection)));
- return result;
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jintArray> result,
+ JniHelper::NewIntArray(env, 2));
+ env->SetIntArrayRegion(result.get(), 0, 1, &(std::get<0>(selection)));
+ env->SetIntArrayRegion(result.get(), 1, 1, &(std::get<1>(selection)));
+ return result.release();
}
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeClassifyText)
@@ -554,23 +603,33 @@
const AnnotatorJniContext* model_context =
reinterpret_cast<AnnotatorJniContext*>(ptr);
- const std::string context_utf8 = ToStlString(env, context);
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
+ ToStlString(env, context));
const CodepointSpan input_indices =
ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end});
- const libtextclassifier3::ClassificationOptions classification_options =
- FromJavaClassificationOptions(env, options);
+ TC3_ASSIGN_OR_RETURN_NULL(
+ const libtextclassifier3::ClassificationOptions classification_options,
+ FromJavaClassificationOptions(env, options));
const std::vector<ClassificationResult> classification_result =
model_context->model()->ClassifyText(context_utf8, input_indices,
classification_options);
+
+ ScopedLocalRef<jobjectArray> result;
if (app_context != nullptr) {
- return ClassificationResultsWithIntentsToJObjectArray(
- env, model_context, app_context, device_locales,
- &classification_options, context_utf8, input_indices,
- classification_result,
- /*generate_intents=*/true);
+ TC3_ASSIGN_OR_RETURN_NULL(
+ result, ClassificationResultsWithIntentsToJObjectArray(
+ env, model_context, app_context, device_locales,
+ &classification_options, context_utf8, input_indices,
+ classification_result,
+ /*generate_intents=*/true));
+
+ } else {
+ TC3_ASSIGN_OR_RETURN_NULL(
+ result, ClassificationResultsToJObjectArray(env, model_context,
+ classification_result));
}
- return ClassificationResultsToJObjectArray(env, model_context,
- classification_result);
+
+ return result.release();
}
TC3_JNI_METHOD(jobjectArray, TC3_ANNOTATOR_CLASS_NAME, nativeAnnotate)
@@ -580,41 +639,46 @@
}
const AnnotatorJniContext* model_context =
reinterpret_cast<AnnotatorJniContext*>(ptr);
- const std::string context_utf8 = ToStlString(env, context);
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string context_utf8,
+ ToStlString(env, context));
+ TC3_ASSIGN_OR_RETURN_NULL(
+ libtextclassifier3::AnnotationOptions annotation_options,
+ FromJavaAnnotationOptions(env, options));
const std::vector<AnnotatedSpan> annotations =
- model_context->model()->Annotate(context_utf8,
- FromJavaAnnotationOptions(env, options));
+ model_context->model()->Annotate(context_utf8, annotation_options);
- jclass result_class = env->FindClass(
- TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan");
- if (!result_class) {
- TC3_LOG(ERROR) << "Couldn't find result class: "
- << TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$AnnotatedSpan";
- return nullptr;
- }
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jclass> result_class,
+ JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotatedSpan"));
jmethodID result_class_constructor =
- env->GetMethodID(result_class, "<init>",
+ env->GetMethodID(result_class.get(), "<init>",
"(II[L" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
"$ClassificationResult;)V");
- jobjectArray results =
- env->NewObjectArray(annotations.size(), result_class, nullptr);
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, annotations.size(), result_class.get()));
for (int i = 0; i < annotations.size(); ++i) {
CodepointSpan span_bmp =
ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span);
- jobject result = env->NewObject(
- result_class, result_class_constructor,
- static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second),
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> classification_results,
ClassificationResultsToJObjectArray(env, model_context,
annotations[i].classification));
- env->SetObjectArrayElement(results, i, result);
- env->DeleteLocalRef(result);
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(env, result_class.get(), result_class_constructor,
+ static_cast<jint>(span_bmp.first),
+ static_cast<jint>(span_bmp.second),
+ classification_results.get()));
+ env->SetObjectArrayElement(results.get(), i, result.get());
}
- env->DeleteLocalRef(result_class);
- return results;
+ return results.release();
}
TC3_JNI_METHOD(jbyteArray, TC3_ANNOTATOR_CLASS_NAME,
@@ -624,16 +688,19 @@
return nullptr;
}
const Annotator* model = reinterpret_cast<AnnotatorJniContext*>(ptr)->model();
- const std::string id_utf8 = ToStlString(env, id);
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string id_utf8, ToStlString(env, id));
std::string serialized_knowledge_result;
if (!model->LookUpKnowledgeEntity(id_utf8, &serialized_knowledge_result)) {
return nullptr;
}
- jbyteArray result = env->NewByteArray(serialized_knowledge_result.size());
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jbyteArray> result,
+ JniHelper::NewByteArray(env, serialized_knowledge_result.size()));
env->SetByteArrayRegion(
- result, 0, serialized_knowledge_result.size(),
+ result.get(), 0, serialized_knowledge_result.size(),
reinterpret_cast<const jbyte*>(serialized_knowledge_result.data()));
- return result;
+ return result.release();
}
TC3_JNI_METHOD(void, TC3_ANNOTATOR_CLASS_NAME, nativeCloseAnnotator)
@@ -654,14 +721,18 @@
(JNIEnv* env, jobject clazz, jint fd) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd));
- return GetLocalesFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetLocalesFromMmap(env, mmap.get()));
+ return value.release();
}
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetLocalesWithOffset)
(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
- return GetLocalesFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetLocalesFromMmap(env, mmap.get()));
+ return value.release();
}
TC3_JNI_METHOD(jint, TC3_ANNOTATOR_CLASS_NAME, nativeGetVersion)
@@ -682,12 +753,16 @@
(JNIEnv* env, jobject clazz, jint fd) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd));
- return GetNameFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetNameFromMmap(env, mmap.get()));
+ return value.release();
}
TC3_JNI_METHOD(jstring, TC3_ANNOTATOR_CLASS_NAME, nativeGetNameWithOffset)
(JNIEnv* env, jobject thiz, jint fd, jlong offset, jlong size) {
const std::unique_ptr<libtextclassifier3::ScopedMmap> mmap(
new libtextclassifier3::ScopedMmap(fd, offset, size));
- return GetNameFromMmap(env, mmap.get());
+ TC3_ASSIGN_OR_RETURN_NULL(ScopedLocalRef<jstring> value,
+ GetNameFromMmap(env, mmap.get()));
+ return value.release();
}
diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc
index 55f14e6..575e71b 100644
--- a/native/annotator/annotator_jni_common.cc
+++ b/native/annotator/annotator_jni_common.cc
@@ -17,138 +17,176 @@
#include "annotator/annotator_jni_common.h"
#include "utils/java/jni-base.h"
-#include "utils/java/scoped_local_ref.h"
+#include "utils/java/jni-helper.h"
namespace libtextclassifier3 {
namespace {
-std::unordered_set<std::string> EntityTypesFromJObject(JNIEnv* env,
- const jobject& jobject) {
+StatusOr<std::unordered_set<std::string>> EntityTypesFromJObject(
+ JNIEnv* env, const jobject& jobject) {
std::unordered_set<std::string> entity_types;
jobjectArray jentity_types = reinterpret_cast<jobjectArray>(jobject);
const int size = env->GetArrayLength(jentity_types);
for (int i = 0; i < size; ++i) {
- jstring jentity_type =
- reinterpret_cast<jstring>(env->GetObjectArrayElement(jentity_types, i));
- entity_types.insert(ToStlString(env, jentity_type));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> jentity_type,
+ JniHelper::GetObjectArrayElement<jstring>(env, jentity_types, i));
+ TC3_ASSIGN_OR_RETURN(std::string entity_type,
+ ToStlString(env, jentity_type.get()));
+ entity_types.insert(entity_type);
}
return entity_types;
}
template <typename T>
-T FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
- const std::string& class_name) {
+StatusOr<T> FromJavaOptionsInternal(JNIEnv* env, jobject joptions,
+ const std::string& class_name) {
if (!joptions) {
- return {};
+ return {Status::UNKNOWN};
}
- const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()),
- env);
- if (!options_class) {
- return {};
- }
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jclass> options_class,
+ JniHelper::FindClass(env, class_name.c_str()));
- const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
- env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
- "getLocale", "Ljava/lang/String;");
- const std::pair<bool, jobject> status_or_reference_timezone =
- CallJniMethod0<jobject>(env, joptions, options_class.get(),
- &JNIEnv::CallObjectMethod, "getReferenceTimezone",
- "Ljava/lang/String;");
- const std::pair<bool, int64> status_or_reference_time_ms_utc =
- CallJniMethod0<int64>(env, joptions, options_class.get(),
- &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc",
- "J");
- const std::pair<bool, jobject> status_or_detected_text_language_tags =
- CallJniMethod0<jobject>(
- env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
- "getDetectedTextLanguageTags", "Ljava/lang/String;");
- const std::pair<bool, int> status_or_annotation_usecase =
- CallJniMethod0<int>(env, joptions, options_class.get(),
- &JNIEnv::CallIntMethod, "getAnnotationUsecase", "I");
+ // .getLocale()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_locale,
+ JniHelper::GetMethodID(env, options_class.get(), "getLocale",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> locales,
+ JniHelper::CallObjectMethod<jstring>(env, joptions, get_locale));
- if (!status_or_locales.first || !status_or_reference_timezone.first ||
- !status_or_reference_time_ms_utc.first ||
- !status_or_detected_text_language_tags.first ||
- !status_or_annotation_usecase.first) {
- return {};
- }
+ // .getReferenceTimeMsUtc()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_reference_time_method,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getReferenceTimeMsUtc", "()J"));
+ TC3_ASSIGN_OR_RETURN(
+ int64 reference_time,
+ JniHelper::CallLongMethod(env, joptions, get_reference_time_method));
+
+ // .getReferenceTimezone()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_reference_timezone_method,
+ JniHelper::GetMethodID(env, options_class.get(), "getReferenceTimezone",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> reference_timezone,
+ JniHelper::CallObjectMethod<jstring>(
+ env, joptions, get_reference_timezone_method));
+
+ // .getDetectedTextLanguageTags()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_detected_text_language_tags_method,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getDetectedTextLanguageTags",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> detected_text_language_tags,
+ JniHelper::CallObjectMethod<jstring>(
+ env, joptions, get_detected_text_language_tags_method));
+
+ // .getAnnotationUsecase()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_annotation_usecase,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getAnnotationUsecase", "()I"));
+ TC3_ASSIGN_OR_RETURN(
+ int32 annotation_usecase,
+ JniHelper::CallIntMethod(env, joptions, get_annotation_usecase));
T options;
- options.locales =
- ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
- options.reference_timezone = ToStlString(
- env, reinterpret_cast<jstring>(status_or_reference_timezone.second));
- options.reference_time_ms_utc = status_or_reference_time_ms_utc.second;
- options.detected_text_language_tags = ToStlString(
- env,
- reinterpret_cast<jstring>(status_or_detected_text_language_tags.second));
+ TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get()));
+ TC3_ASSIGN_OR_RETURN(options.reference_timezone,
+ ToStlString(env, reference_timezone.get()));
+ options.reference_time_ms_utc = reference_time;
+ TC3_ASSIGN_OR_RETURN(options.detected_text_language_tags,
+ ToStlString(env, detected_text_language_tags.get()));
options.annotation_usecase =
- static_cast<AnnotationUsecase>(status_or_annotation_usecase.second);
+ static_cast<AnnotationUsecase>(annotation_usecase);
return options;
}
} // namespace
-SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) {
+StatusOr<SelectionOptions> FromJavaSelectionOptions(JNIEnv* env,
+ jobject joptions) {
if (!joptions) {
- return {};
+ return {Status::UNKNOWN};
}
- const ScopedLocalRef<jclass> options_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$SelectionOptions"),
- env);
- const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>(
- env, joptions, options_class.get(), &JNIEnv::CallObjectMethod,
- "getLocales", "Ljava/lang/String;");
- const std::pair<bool, int> status_or_annotation_usecase =
- CallJniMethod0<int>(env, joptions, options_class.get(),
- &JNIEnv::CallIntMethod, "getAnnotationUsecase", "I");
- if (!status_or_locales.first || !status_or_annotation_usecase.first) {
- return {};
- }
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> options_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$SelectionOptions"));
+
+ // .getLocale()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_locales,
+ JniHelper::GetMethodID(env, options_class.get(), "getLocales",
+ "()Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> locales,
+ JniHelper::CallObjectMethod<jstring>(env, joptions, get_locales));
+
+ // .getAnnotationUsecase()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_annotation_usecase,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "getAnnotationUsecase", "()I"));
+ TC3_ASSIGN_OR_RETURN(
+ int32 annotation_usecase,
+ JniHelper::CallIntMethod(env, joptions, get_annotation_usecase));
SelectionOptions options;
- options.locales =
- ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second));
+ TC3_ASSIGN_OR_RETURN(options.locales, ToStlString(env, locales.get()));
options.annotation_usecase =
- static_cast<AnnotationUsecase>(status_or_annotation_usecase.second);
+ static_cast<AnnotationUsecase>(annotation_usecase);
return options;
}
-ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
- jobject joptions) {
+StatusOr<ClassificationOptions> FromJavaClassificationOptions(
+ JNIEnv* env, jobject joptions) {
return FromJavaOptionsInternal<ClassificationOptions>(
env, joptions,
TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$ClassificationOptions");
}
-AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) {
- if (!joptions) return {};
- const ScopedLocalRef<jclass> options_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$AnnotationOptions"),
- env);
- if (!options_class) return {};
- const std::pair<bool, jobject> status_or_entity_types =
- CallJniMethod0<jobject>(env, joptions, options_class.get(),
- &JNIEnv::CallObjectMethod, "getEntityTypes",
- "[Ljava/lang/String;");
- if (!status_or_entity_types.first) return {};
- const std::pair<bool, bool> status_or_enable_serialized_entity_data =
- CallJniMethod0<bool>(env, joptions, options_class.get(),
- &JNIEnv::CallBooleanMethod,
- "isSerializedEntityDataEnabled", "Z");
- if (!status_or_enable_serialized_entity_data.first) return {};
- AnnotationOptions annotation_options =
+StatusOr<AnnotationOptions> FromJavaAnnotationOptions(JNIEnv* env,
+ jobject joptions) {
+ if (!joptions) {
+ return {Status::UNKNOWN};
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jclass> options_class,
+ JniHelper::FindClass(env, TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
+ "$AnnotationOptions"));
+
+ // .getEntityTypes()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID get_entity_types,
+ JniHelper::GetMethodID(env, options_class.get(), "getEntityTypes",
+ "()[Ljava/lang/String;"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> entity_types,
+ JniHelper::CallObjectMethod<jobject>(env, joptions, get_entity_types));
+
+ // .isSerializedEntityDataEnabled()
+ TC3_ASSIGN_OR_RETURN(
+ jmethodID is_serialized_entity_data_enabled_method,
+ JniHelper::GetMethodID(env, options_class.get(),
+ "isSerializedEntityDataEnabled", "()Z"));
+ TC3_ASSIGN_OR_RETURN(
+ bool is_serialized_entity_data_enabled,
+ JniHelper::CallBooleanMethod(env, joptions,
+ is_serialized_entity_data_enabled_method));
+
+ TC3_ASSIGN_OR_RETURN(
+ AnnotationOptions annotation_options,
FromJavaOptionsInternal<AnnotationOptions>(
env, joptions,
- TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions");
- annotation_options.entity_types =
- EntityTypesFromJObject(env, status_or_entity_types.second);
+ TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR "$AnnotationOptions"));
+ TC3_ASSIGN_OR_RETURN(annotation_options.entity_types,
+ EntityTypesFromJObject(env, entity_types.get()));
annotation_options.is_serialized_entity_data_enabled =
- status_or_enable_serialized_entity_data.second;
+ is_serialized_entity_data_enabled;
return annotation_options;
}
diff --git a/native/annotator/annotator_jni_common.h b/native/annotator/annotator_jni_common.h
index b62bb21..f1f1d88 100644
--- a/native/annotator/annotator_jni_common.h
+++ b/native/annotator/annotator_jni_common.h
@@ -20,6 +20,7 @@
#include <jni.h>
#include "annotator/annotator.h"
+#include "utils/base/statusor.h"
#ifndef TC3_ANNOTATOR_CLASS_NAME
#define TC3_ANNOTATOR_CLASS_NAME AnnotatorModel
@@ -29,12 +30,14 @@
namespace libtextclassifier3 {
-SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions);
-
-ClassificationOptions FromJavaClassificationOptions(JNIEnv* env,
+StatusOr<SelectionOptions> FromJavaSelectionOptions(JNIEnv* env,
jobject joptions);
-AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions);
+StatusOr<ClassificationOptions> FromJavaClassificationOptions(JNIEnv* env,
+ jobject joptions);
+
+StatusOr<AnnotationOptions> FromJavaAnnotationOptions(JNIEnv* env,
+ jobject joptions);
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/parser.cc b/native/annotator/datetime/parser.cc
index 0f222bd..6c759e7 100644
--- a/native/annotator/datetime/parser.cc
+++ b/native/annotator/datetime/parser.cc
@@ -92,7 +92,7 @@
}
if (model->locales() != nullptr) {
- for (int i = 0; i < model->locales()->Length(); ++i) {
+ for (int i = 0; i < model->locales()->size(); ++i) {
locale_string_to_id_[model->locales()->Get(i)->str()] = i;
}
}
@@ -106,6 +106,8 @@
use_extractors_for_locating_ = model->use_extractors_for_locating();
generate_alternative_interpretations_when_ambiguous_ =
model->generate_alternative_interpretations_when_ambiguous();
+ prefer_future_for_unspecified_date_ =
+ model->prefer_future_for_unspecified_date();
initialized_ = true;
}
@@ -433,7 +435,8 @@
// response. For Details see b/130355975
if (!calendarlib_.InterpretParseData(
interpretation, reference_time_ms_utc, reference_timezone,
- reference_locale, &(result.time_ms_utc), &(result.granularity))) {
+ reference_locale, prefer_future_for_unspecified_date_,
+ &(result.time_ms_utc), &(result.granularity))) {
return false;
}
diff --git a/native/annotator/datetime/parser.h b/native/annotator/datetime/parser.h
index 4e995bd..a5192d3 100644
--- a/native/annotator/datetime/parser.h
+++ b/native/annotator/datetime/parser.h
@@ -59,12 +59,6 @@
bool anchor_start_end,
std::vector<DatetimeParseResultSpan>* results) const;
-#ifdef TC3_TEST_ONLY
- void TestOnlySetGenerateAlternativeInterpretationsWhenAmbiguous(bool value) {
- generate_alternative_interpretations_when_ambiguous_ = value;
- }
-#endif // TC3_TEST_ONLY
-
protected:
DatetimeParser(const DatetimeModel* model, const UniLib& unilib,
const CalendarLib& calendarlib,
@@ -126,6 +120,7 @@
std::vector<int> default_locale_ids_;
bool use_extractors_for_locating_;
bool generate_alternative_interpretations_when_ambiguous_;
+ bool prefer_future_for_unspecified_date_;
};
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/parser_test.cc b/native/annotator/datetime/parser_test.cc
index 35c725f..1ddcf50 100644
--- a/native/annotator/datetime/parser_test.cc
+++ b/native/annotator/datetime/parser_test.cc
@@ -14,20 +14,21 @@
* limitations under the License.
*/
+#include "annotator/datetime/parser.h"
+
#include <time.h>
+
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
#include "annotator/annotator.h"
-#include "annotator/datetime/parser.h"
#include "annotator/model_generated.h"
#include "annotator/types-test-util.h"
#include "utils/testing/annotator.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
using std::vector;
using testing::ElementsAreArray;
@@ -152,7 +153,7 @@
const int expected_start_index =
std::distance(marked_text_unicode.begin(), brace_open_it);
- // The -1 bellow is to account for the opening bracket character.
+ // The -1 below is to account for the opening bracket character.
const int expected_end_index =
std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
@@ -746,6 +747,43 @@
/*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_SMART));
}
+TEST_F(ParserTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
+ // ParsesCorrectly uses 0 as the reference time, which corresponds to:
+ // "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
+ // it is in the past, and so the parser should move this to the next day ->
+ // "Fri Jan 02 1970 00:30:00" Zurich time (b/139112907).
+ EXPECT_TRUE(ParsesCorrectly(
+ "{0:30am}", 84600000L /* 23.5 hours from reference time */,
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Build()}));
+}
+
+TEST_F(ParserTest, DoesNotAddADayWhenTimeInThePastAndDayNotSpecifiedDisabled) {
+ // ParsesCorrectly uses 0 as the reference time, which corresponds to:
+ // "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
+ // it is in the past. The parameter prefer_future_when_unspecified_day is
+ // disabled, so the parser should annotate this to the same day: "Thu Jan 01
+ // 1970 00:30:00" Zurich time.
+ LoadModel([](ModelT* model) {
+ // In the test model, the prefer_future_for_unspecified_date is true; make
+ // it false only for this test.
+ model->datetime_model->prefer_future_for_unspecified_date = false;
+ });
+
+ EXPECT_TRUE(ParsesCorrectly(
+ "{0:30am}", -1800000L /* -30 minutes from reference time */,
+ GRANULARITY_MINUTE,
+ {DatetimeComponentsBuilder()
+ .Add(DatetimeComponent::ComponentType::MERIDIEM, 0)
+ .Add(DatetimeComponent::ComponentType::MINUTE, 30)
+ .Add(DatetimeComponent::ComponentType::HOUR, 0)
+ .Build()}));
+}
+
TEST_F(ParserTest, ParsesNoonAndMidnightCorrectly) {
EXPECT_TRUE(ParsesCorrectly(
"{January 1, 1988 12:30am}", 567991800000, GRANULARITY_MINUTE,
diff --git a/native/annotator/duration/duration.cc b/native/annotator/duration/duration.cc
index 3529691..907a1a4 100644
--- a/native/annotator/duration/duration.cc
+++ b/native/annotator/duration/duration.cc
@@ -23,6 +23,7 @@
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/strings/numbers.h"
+#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -31,46 +32,55 @@
namespace internal {
namespace {
+std::string ToLowerString(const std::string& str, const UniLib* unilib) {
+ return unilib->ToLowerText(UTF8ToUnicodeText(str, /*do_copy=*/false))
+ .ToUTF8String();
+}
+
void FillDurationUnitMap(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
expressions,
DurationUnit duration_unit,
- std::unordered_map<std::string, DurationUnit>* target_map) {
+ std::unordered_map<std::string, DurationUnit>* target_map,
+ const UniLib* unilib) {
if (expressions == nullptr) {
return;
}
for (const flatbuffers::String* expression_string : *expressions) {
- (*target_map)[expression_string->c_str()] = duration_unit;
+ (*target_map)[ToLowerString(expression_string->c_str(), unilib)] =
+ duration_unit;
}
}
} // namespace
std::unordered_map<std::string, DurationUnit> BuildTokenToDurationUnitMapping(
- const DurationAnnotatorOptions* options) {
+ const DurationAnnotatorOptions* options, const UniLib* unilib) {
std::unordered_map<std::string, DurationUnit> mapping;
- FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK,
- &mapping);
- FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping);
- FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR,
- &mapping);
+ FillDurationUnitMap(options->week_expressions(), DurationUnit::WEEK, &mapping,
+ unilib);
+ FillDurationUnitMap(options->day_expressions(), DurationUnit::DAY, &mapping,
+ unilib);
+ FillDurationUnitMap(options->hour_expressions(), DurationUnit::HOUR, &mapping,
+ unilib);
FillDurationUnitMap(options->minute_expressions(), DurationUnit::MINUTE,
- &mapping);
+ &mapping, unilib);
FillDurationUnitMap(options->second_expressions(), DurationUnit::SECOND,
- &mapping);
+ &mapping, unilib);
return mapping;
}
std::unordered_set<std::string> BuildStringSet(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
- strings) {
+ strings,
+ const UniLib* unilib) {
std::unordered_set<std::string> result;
if (strings == nullptr) {
return result;
}
for (const flatbuffers::String* string_value : *strings) {
- result.insert(string_value->c_str());
+ result.insert(ToLowerString(string_value->c_str(), unilib));
}
return result;
@@ -260,14 +270,17 @@
std::string token_value_buffer;
const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
token.value, &token_value_buffer);
+ const std::string& lowercase_token_value =
+ internal::ToLowerString(token_value, unilib_);
- if (half_expressions_.find(token_value) != half_expressions_.end()) {
+ if (half_expressions_.find(lowercase_token_value) !=
+ half_expressions_.end()) {
value->plus_half = true;
return true;
}
int32 parsed_value;
- if (ParseInt32(token_value.c_str(), &parsed_value)) {
+ if (ParseInt32(lowercase_token_value.c_str(), &parsed_value)) {
value->value = parsed_value;
return true;
}
@@ -280,8 +293,10 @@
std::string token_value_buffer;
const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
token.value, &token_value_buffer);
+ const std::string& lowercase_token_value =
+ internal::ToLowerString(token_value, unilib_);
- const auto it = token_value_to_duration_unit_.find(token_value);
+ const auto it = token_value_to_duration_unit_.find(lowercase_token_value);
if (it == token_value_to_duration_unit_.end()) {
return false;
}
@@ -319,8 +334,11 @@
std::string token_value_buffer;
const std::string& token_value = feature_processor_->StripBoundaryCodepoints(
token.value, &token_value_buffer);
+ const std::string& lowercase_token_value =
+ internal::ToLowerString(token_value, unilib_);
- if (filler_expressions_.find(token_value) == filler_expressions_.end()) {
+ if (filler_expressions_.find(lowercase_token_value) ==
+ filler_expressions_.end()) {
return false;
}
diff --git a/native/annotator/duration/duration.h b/native/annotator/duration/duration.h
index 2242259..db4bdae 100644
--- a/native/annotator/duration/duration.h
+++ b/native/annotator/duration/duration.h
@@ -26,6 +26,7 @@
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
namespace libtextclassifier3 {
@@ -46,12 +47,14 @@
// Prepares the mapping between token values and duration unit types.
std::unordered_map<std::string, internal::DurationUnit>
-BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions* options);
+BuildTokenToDurationUnitMapping(const DurationAnnotatorOptions* options,
+ const UniLib* unilib);
// Creates a set of strings from a flatbuffer string vector.
std::unordered_set<std::string> BuildStringSet(
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>*
- strings);
+ strings,
+ const UniLib* unilib);
// Creates a set of ints from a flatbuffer int vector.
std::unordered_set<int32> BuildInt32Set(const flatbuffers::Vector<int32>* ints);
@@ -62,15 +65,17 @@
class DurationAnnotator {
public:
explicit DurationAnnotator(const DurationAnnotatorOptions* options,
- const FeatureProcessor* feature_processor)
+ const FeatureProcessor* feature_processor,
+ const UniLib* unilib)
: options_(options),
feature_processor_(feature_processor),
+ unilib_(unilib),
token_value_to_duration_unit_(
- internal::BuildTokenToDurationUnitMapping(options)),
+ internal::BuildTokenToDurationUnitMapping(options, unilib)),
filler_expressions_(
- internal::BuildStringSet(options->filler_expressions())),
+ internal::BuildStringSet(options->filler_expressions(), unilib)),
half_expressions_(
- internal::BuildStringSet(options->half_expressions())),
+ internal::BuildStringSet(options->half_expressions(), unilib)),
sub_token_separator_codepoints_(internal::BuildInt32Set(
options->sub_token_separator_codepoints())) {}
@@ -125,6 +130,7 @@
const DurationAnnotatorOptions* options_;
const FeatureProcessor* feature_processor_;
+ const UniLib* unilib_;
const std::unordered_map<std::string, internal::DurationUnit>
token_value_to_duration_unit_;
const std::unordered_set<std::string> filler_expressions_;
diff --git a/native/annotator/duration/duration_test.cc b/native/annotator/duration/duration_test.cc
index 3fc25e6..d1dc67a 100644
--- a/native/annotator/duration/duration_test.cc
+++ b/native/annotator/duration/duration_test.cc
@@ -106,7 +106,7 @@
: INIT_UNILIB_FOR_TESTING(unilib_),
feature_processor_(BuildFeatureProcessor(&unilib_)),
duration_annotator_(TestingDurationAnnotatorOptions(),
- feature_processor_.get()) {}
+ feature_processor_.get(), &unilib_) {}
std::vector<Token> Tokenize(const UnicodeText& text) {
return feature_processor_->Tokenize(text);
@@ -195,6 +195,26 @@
3 * 60 * 60 * 1000 + 5 * 1000)))))));
}
+TEST_F(DurationAnnotatorTest, AllUnitsAreCovered) {
+ const UnicodeText text = UTF8ToUnicodeText(
+ "See you in a week and a day and an hour and a minute and a second");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(13, 65)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 7 * 24 * 60 * 60 * 1000 + 24 * 60 * 60 * 1000 +
+ 60 * 60 * 1000 + 60 * 1000 + 1000)))))));
+}
+
TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
const UnicodeText text = UTF8ToUnicodeText("Set a timer for half an hour");
std::vector<Token> tokens = Tokenize(text);
@@ -350,5 +370,62 @@
1400L * 60L * 60L * 1000L)));
}
+TEST_F(DurationAnnotatorTest, FindsSimpleDurationIgnoringCase) {
+ const UnicodeText text = UTF8ToUnicodeText("Wake me up in 15 MiNuTeS ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 15 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpressionIgnoringCase) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 3 and HaLf minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 3.5 * 60 * 1000)))))));
+}
+
+TEST_F(DurationAnnotatorTest,
+ FindsDurationWithHalfExpressionIgnoringFillerWordCase) {
+ const UnicodeText text =
+ UTF8ToUnicodeText("Set a timer for 3 AnD half minutes ok?");
+ std::vector<Token> tokens = Tokenize(text);
+ std::vector<AnnotatedSpan> result;
+ EXPECT_TRUE(duration_annotator_.FindAll(
+ text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
+
+ EXPECT_THAT(
+ result,
+ ElementsAre(
+ AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
+ Field(&AnnotatedSpan::classification,
+ ElementsAre(AllOf(
+ Field(&ClassificationResult::collection, "duration"),
+ Field(&ClassificationResult::duration_ms,
+ 3.5 * 60 * 1000)))))));
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/native/annotator/entity-data.fbs b/native/annotator/entity-data.fbs
index 6da3dd5..fa2dc0b 100755
--- a/native/annotator/entity-data.fbs
+++ b/native/annotator/entity-data.fbs
@@ -125,6 +125,46 @@
flight_number:string (shared);
}
+// Details about an ISBN number.
+namespace libtextclassifier3.EntityData_;
+table Isbn {
+ // The (normalized) number.
+ number:string (shared);
+}
+
+// Details about an IBAN number.
+namespace libtextclassifier3.EntityData_;
+table Iban {
+ // The (normalized) number.
+ number:string (shared);
+
+ // The country code.
+ country_code:string (shared);
+}
+
+namespace libtextclassifier3.EntityData_.ParcelTracking_;
+enum Carrier : int {
+ UNKNOWN_CARRIER = 0,
+ FEDEX = 1,
+ UPS = 2,
+ DHL = 3,
+ USPS = 4,
+ ONTRAC = 5,
+ LASERSHIP = 6,
+ ISRAEL_POST = 7,
+ SWISS_POST = 8,
+ MSC = 9,
+ AMAZON = 10,
+ I_PARCEL = 11,
+}
+
+// Details about a tracking number.
+namespace libtextclassifier3.EntityData_;
+table ParcelTracking {
+ carrier:ParcelTracking_.Carrier;
+ tracking_number:string (shared);
+}
+
// Represents an entity annotated in text.
namespace libtextclassifier3;
table EntityData {
@@ -143,6 +183,9 @@
app:EntityData_.App;
payment_card:EntityData_.PaymentCard;
flight:EntityData_.Flight;
+ isbn:EntityData_.Isbn;
+ iban:EntityData_.Iban;
+ parcel:EntityData_.ParcelTracking;
}
root_type libtextclassifier3.EntityData;
diff --git a/native/annotator/flatbuffer-utils.cc b/native/annotator/flatbuffer-utils.cc
new file mode 100644
index 0000000..14b5901
--- /dev/null
+++ b/native/annotator/flatbuffer-utils.cc
@@ -0,0 +1,65 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/flatbuffer-utils.h"
+
+#include <memory>
+
+#include "utils/base/logging.h"
+#include "utils/flatbuffers.h"
+#include "flatbuffers/reflection.h"
+
+namespace libtextclassifier3 {
+
+bool SwapFieldNamesForOffsetsInPath(ModelT* model) {
+ if (model->regex_model == nullptr || model->entity_data_schema.empty()) {
+ // Nothing to do.
+ return true;
+ }
+ const reflection::Schema* schema =
+ LoadAndVerifyFlatbuffer<reflection::Schema>(
+ model->entity_data_schema.data(), model->entity_data_schema.size());
+
+ for (std::unique_ptr<RegexModel_::PatternT>& pattern :
+ model->regex_model->patterns) {
+ for (std::unique_ptr<RegexModel_::Pattern_::CapturingGroupT>& group :
+ pattern->capturing_group) {
+ if (group->entity_field_path == nullptr) {
+ continue;
+ }
+
+ if (!SwapFieldNamesForOffsetsInPath(schema,
+ group->entity_field_path.get())) {
+ return false;
+ }
+ }
+ }
+
+ return true;
+}
+
+std::string SwapFieldNamesForOffsetsInPathInSerializedModel(
+ const std::string& model) {
+ std::unique_ptr<ModelT> unpacked_model = UnPackModel(model.c_str());
+ TC3_CHECK(unpacked_model != nullptr);
+ TC3_CHECK(SwapFieldNamesForOffsetsInPath(unpacked_model.get()));
+ flatbuffers::FlatBufferBuilder builder;
+ FinishModelBuffer(builder, Model::Pack(builder, unpacked_model.get()));
+ return std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
+ builder.GetSize());
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/flatbuffer-utils.h b/native/annotator/flatbuffer-utils.h
new file mode 100644
index 0000000..a7e5d64
--- /dev/null
+++ b/native/annotator/flatbuffer-utils.h
@@ -0,0 +1,38 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility functions for working with FlatBuffers in the annotator model.
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
+
+#include <string>
+
+#include "annotator/model_generated.h"
+
+namespace libtextclassifier3 {
+
+// Resolves field lookups by name to the concrete field offsets in the regex
+// rules of the model.
+bool SwapFieldNamesForOffsetsInPath(ModelT* model);
+
+// Same as above but for a serialized model.
+std::string SwapFieldNamesForOffsetsInPathInSerializedModel(
+ const std::string& model);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_FLATBUFFER_UTILS_H_
diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h
index 1787353..865bf85 100644
--- a/native/annotator/knowledge/knowledge-engine-dummy.h
+++ b/native/annotator/knowledge/knowledge-engine-dummy.h
@@ -29,6 +29,8 @@
public:
bool Initialize(const std::string& serialized_config) { return true; }
+ void SetPriorityScore(float priority_score) {}
+
bool ClassifyText(const std::string& context, CodepointSpan selection_indices,
AnnotationUsecase annotation_usecase,
ClassificationResult* classification_result) const {
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index 181a8aa..5bf1472 100755
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -17,6 +17,7 @@
include "utils/codepoint-range.fbs";
include "utils/flatbuffers.fbs";
include "utils/intents/intent-config.fbs";
+include "utils/normalization.fbs";
include "utils/resources.fbs";
include "utils/tokenizer.fbs";
include "utils/zlib/buffer.fbs";
@@ -209,6 +210,9 @@
// If set, the serialized entity data will be merged with the
// classification result entity data.
serialized_entity_data:string (shared);
+
+ // If set, normalization to apply before text is used in entity data.
+ normalization_options:NormalizationOptions;
}
// List of regular expression matchers to check.
@@ -329,6 +333,9 @@
// If true, will compile the regexes only on first use.
lazy_regex_compilation:bool = true;
+
+ // If true, will give only future dates (when the day is not specified).
+ prefer_future_for_unspecified_date:bool = false;
}
namespace libtextclassifier3.DatetimeModelLibrary_;
@@ -363,6 +370,9 @@
// Priority score assigned to the "other" class from ML model.
other_collection_priority_score:float = -1000;
+
+ // Priority score assigned to knowledge engine annotations.
+ knowledge_priority_score:float = 0;
}
// Options controlling the output of the classifier.
@@ -675,6 +685,10 @@
// Priority score for the percentage annotation.
percentage_priority_score:float = 1;
+
+ // Float number priority score used for conflict resolution with the other
+ // models.
+ float_number_priority_score:float = 0;
}
// DurationAnnotator is so far tailored for English only.
diff --git a/native/annotator/number/number.cc b/native/annotator/number/number.cc
index 7af63fa..671e1af 100644
--- a/native/annotator/number/number.cc
+++ b/native/annotator/number/number.cc
@@ -20,7 +20,9 @@
#include <cstdlib>
#include "annotator/collections.h"
+#include "annotator/types.h"
#include "utils/base/logging.h"
+#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -28,68 +30,38 @@
const UnicodeText& context, CodepointSpan selection_indices,
AnnotationUsecase annotation_usecase,
ClassificationResult* classification_result) const {
- if (!options_->enabled() || ((1 << annotation_usecase) &
- options_->enabled_annotation_usecases()) == 0) {
+ TC3_CHECK(classification_result != nullptr);
+
+ const UnicodeText substring_selected = UnicodeText::Substring(
+ context, selection_indices.first, selection_indices.second);
+
+ std::vector<AnnotatedSpan> results;
+ if (!FindAll(substring_selected, annotation_usecase, &results)) {
return false;
}
- int64 parsed_int_value;
- double parsed_double_value;
- int num_prefix_codepoints;
- int num_suffix_codepoints;
- const UnicodeText substring_selected = UnicodeText::Substring(
- context, selection_indices.first, selection_indices.second);
- if (ParseNumber(substring_selected, &parsed_int_value, &parsed_double_value,
- &num_prefix_codepoints, &num_suffix_codepoints)) {
- TC3_CHECK(classification_result != nullptr);
- classification_result->score = options_->score();
- classification_result->priority_score = options_->priority_score();
- classification_result->numeric_value = parsed_int_value;
- classification_result->numeric_double_value = parsed_double_value;
+ const CodepointSpan stripped_selection_indices =
+ feature_processor_->StripBoundaryCodepoints(
+ context, selection_indices, ignored_prefix_span_boundary_codepoints_,
+ ignored_suffix_span_boundary_codepoints_);
- if (num_suffix_codepoints == 0) {
- classification_result->collection = Collections::Number();
- return true;
+ for (const AnnotatedSpan& result : results) {
+ if (result.classification.empty()) {
+ continue;
}
- // If the selection ends in %, the parseNumber returns true with
- // num_suffix_codepoints = 1 => percent
- if (options_->enable_percentage() &&
- GetPercentSuffixLength(
- context, context.size_codepoints(),
- selection_indices.second - num_suffix_codepoints) ==
- num_suffix_codepoints) {
- classification_result->collection = Collections::Percentage();
- classification_result->priority_score =
- options_->percentage_priority_score();
+ // We make sure that the result span is equal to the stripped selection span
+ // to avoid validating cases like "23 asdf 3.14 pct asdf". FindAll will
+ // anyway only find valid numbers and percentages and a given selection with
+ // more than two tokens won't pass this check.
+ if (result.span.first + selection_indices.first ==
+ stripped_selection_indices.first &&
+ result.span.second + selection_indices.first ==
+ stripped_selection_indices.second) {
+ *classification_result = result.classification[0];
return true;
}
- } else if (options_->enable_percentage()) {
- // If the substring selected is a percent matching the form: 5 percent,
- // 5 pct or 5 pc => percent.
- std::vector<AnnotatedSpan> results;
- FindAll(substring_selected, annotation_usecase, &results);
- for (auto& result : results) {
- if (result.classification.empty() ||
- result.classification[0].collection != Collections::Percentage()) {
- continue;
- }
- if (result.span.first == 0 &&
- result.span.second == substring_selected.size_codepoints()) {
- TC3_CHECK(classification_result != nullptr);
- classification_result->collection = Collections::Percentage();
- classification_result->score = options_->score();
- classification_result->priority_score =
- options_->percentage_priority_score();
- classification_result->numeric_value =
- result.classification[0].numeric_value;
- classification_result->numeric_double_value =
- result.classification[0].numeric_double_value;
- return true;
- }
- }
}
-
return false;
}
@@ -107,21 +79,24 @@
UTF8ToUnicodeText(token.value, /*do_copy=*/false);
int64 parsed_int_value;
double parsed_double_value;
+ bool has_decimal;
int num_prefix_codepoints;
int num_suffix_codepoints;
if (ParseNumber(token_text, &parsed_int_value, &parsed_double_value,
- &num_prefix_codepoints, &num_suffix_codepoints)) {
+ &has_decimal, &num_prefix_codepoints,
+ &num_suffix_codepoints)) {
ClassificationResult classification{Collections::Number(),
options_->score()};
classification.numeric_value = parsed_int_value;
classification.numeric_double_value = parsed_double_value;
- classification.priority_score = options_->priority_score();
+ classification.priority_score =
+ has_decimal ? options_->float_number_priority_score()
+ : options_->priority_score();
AnnotatedSpan annotated_span;
annotated_span.span = {token.start + num_prefix_codepoints,
token.end - num_suffix_codepoints};
annotated_span.classification.push_back(classification);
-
result->push_back(annotated_span);
}
}
@@ -151,7 +126,7 @@
namespace {
bool ParseNextNumericCodepoint(int32 codepoint, int64* current_value) {
- if (*current_value > INT64_MAX / 10) {
+ if (*current_value > INT64_MAX / 10 - 10) {
return false;
}
@@ -163,20 +138,20 @@
UnicodeText::const_iterator ConsumeAndParseNumber(
const UnicodeText::const_iterator& it_begin,
const UnicodeText::const_iterator& it_end, int64* int_result,
- double* double_result) {
+ double* double_result, bool* has_decimal) {
*int_result = 0;
+ *has_decimal = false;
// See if there's a sign in the beginning of the number.
int sign = 1;
auto it = it_begin;
- if (it != it_end) {
+ while (it != it_end && (*it == '-' || *it == '+')) {
if (*it == '-') {
- ++it;
sign = -1;
- } else if (*it == '+') {
- ++it;
+ } else {
sign = 1;
}
+ ++it;
}
enum class State {
@@ -203,6 +178,7 @@
break;
case State::PARSING_FLOATING_PART:
if (*it >= '0' && *it <= '9') {
+ *has_decimal = true;
if (!ParseNextNumericCodepoint(*it, &decimal_result)) {
state = State::PARSING_DONE;
break;
@@ -236,7 +212,7 @@
} // namespace
bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* int_result,
- double* double_result,
+ double* double_result, bool* has_decimal,
int* num_prefix_codepoints,
int* num_suffix_codepoints) const {
TC3_CHECK(int_result != nullptr && double_result != nullptr &&
@@ -258,13 +234,6 @@
// Consume prefix codepoints.
*num_prefix_codepoints = stripped_span.first;
- bool valid_prefix = true;
- // Makes valid_prefix=false for cases like: "#10" where it points to '1'. In
- // this case the adjacent prefix is not an allowed one.
- if (it != text.begin() && allowed_prefix_codepoints_.find(*std::prev(it)) ==
- allowed_prefix_codepoints_.end()) {
- valid_prefix = false;
- }
while (it != it_end) {
if (allowed_prefix_codepoints_.find(*it) ==
allowed_prefix_codepoints_.end()) {
@@ -276,7 +245,8 @@
}
auto it_start = it;
- it = ConsumeAndParseNumber(it, it_end, int_result, double_result);
+ it =
+ ConsumeAndParseNumber(it, it_end, int_result, double_result, has_decimal);
if (it == it_start) {
return false;
}
@@ -284,32 +254,35 @@
// Consume suffix codepoints.
bool valid_suffix = true;
*num_suffix_codepoints = 0;
+ int ignored_suffix_codepoints = 0;
while (it != it_end) {
- if (allowed_suffix_codepoints_.find(*it) ==
+ if (allowed_suffix_codepoints_.find(*it) !=
allowed_suffix_codepoints_.end()) {
+ // Keep track of allowed suffix codepoints.
+ ++(*num_suffix_codepoints);
+ } else if (ignored_suffix_span_boundary_codepoints_.find(*it) ==
+ ignored_suffix_span_boundary_codepoints_.end()) {
+ // There is a suffix codepoint but it's not part of the ignored list of
+ // codepoints, fail the number parsing.
+ // Note: We want to support cases like "13.", "34#", "123!" etc.
valid_suffix = false;
break;
+ } else {
+ ++ignored_suffix_codepoints;
}
++it;
- ++(*num_suffix_codepoints);
}
*num_suffix_codepoints += num_stripped_end;
- // Makes valid_suffix=false for cases like: "10@", when it == it_end and
- // points to '@'. This adjacent character is not an allowed suffix.
- if (it == it_end && it != text.end() &&
- allowed_suffix_codepoints_.find(*it) ==
- allowed_suffix_codepoints_.end()) {
- valid_suffix = false;
- }
-
- return valid_suffix && valid_prefix;
+ return valid_suffix;
}
int NumberAnnotator::GetPercentSuffixLength(const UnicodeText& context,
- int context_size_codepoints,
int index_codepoints) const {
+ if (index_codepoints >= context.size_codepoints()) {
+ return -1;
+ }
auto context_it = context.begin();
std::advance(context_it, index_codepoints);
const StringPiece suffix_context(
@@ -329,15 +302,13 @@
void NumberAnnotator::FindPercentages(
const UnicodeText& context, std::vector<AnnotatedSpan>* result) const {
- int context_size_codepoints = context.size_codepoints();
for (auto& res : *result) {
if (res.classification.empty() ||
res.classification[0].collection != Collections::Number()) {
continue;
}
- const int match_length = GetPercentSuffixLength(
- context, context_size_codepoints, res.span.second);
+ const int match_length = GetPercentSuffixLength(context, res.span.second);
if (match_length > 0) {
res.classification[0].collection = Collections::Percentage();
res.classification[0].priority_score =
diff --git a/native/annotator/number/number.h b/native/annotator/number/number.h
index 3debd09..3e9e2c3 100644
--- a/native/annotator/number/number.h
+++ b/native/annotator/number/number.h
@@ -81,16 +81,17 @@
static std::vector<uint32> FlatbuffersIntVectorToStdVector(
const flatbuffers::Vector<int32_t>* ints);
- // Parses the text to an int64 value and returns true if succeeded, otherwise
- // false. Also returns the number of prefix/suffix codepoints that were
- // stripped from the number.
+ // Parses the text to an int64 value and a double value and returns true if
+ // succeeded, otherwise false. Also returns whether the number contains a
+ // decimal and the number of prefix/suffix codepoints that were stripped from
+ // the number.
bool ParseNumber(const UnicodeText& text, int64* int_result,
- double* double_result, int* num_prefix_codepoints,
+ double* double_result, bool* has_decimal,
+ int* num_prefix_codepoints,
int* num_suffix_codepoints) const;
// Get the length of the percent suffix at the specified index in the context.
int GetPercentSuffixLength(const UnicodeText& context,
- int context_size_codepoints,
int index_codepoints) const;
// Checks if the annotated numbers from the context represent percentages.
diff --git a/native/annotator/test_data/test_model.fb b/native/annotator/test_data/test_model.fb
index ce5f72f..bbf730e 100644
--- a/native/annotator/test_data/test_model.fb
+++ b/native/annotator/test_data/test_model.fb
Binary files differ
diff --git a/native/annotator/test_data/wrong_embeddings.fb b/native/annotator/test_data/wrong_embeddings.fb
index efefa3c..135dec0 100644
--- a/native/annotator/test_data/wrong_embeddings.fb
+++ b/native/annotator/test_data/wrong_embeddings.fb
Binary files differ
diff --git a/native/annotator/types.cc b/native/annotator/types.cc
index c31097d..1ec3790 100644
--- a/native/annotator/types.cc
+++ b/native/annotator/types.cc
@@ -56,6 +56,64 @@
}
} // namespace
+std::string ComponentTypeToString(
+ const DatetimeComponent::ComponentType& component_type) {
+ switch (component_type) {
+ case DatetimeComponent::ComponentType::UNSPECIFIED:
+ return "UNSPECIFIED";
+ case DatetimeComponent::ComponentType::YEAR:
+ return "YEAR";
+ case DatetimeComponent::ComponentType::MONTH:
+ return "MONTH";
+ case DatetimeComponent::ComponentType::WEEK:
+ return "WEEK";
+ case DatetimeComponent::ComponentType::DAY_OF_WEEK:
+ return "DAY_OF_WEEK";
+ case DatetimeComponent::ComponentType::DAY_OF_MONTH:
+ return "DAY_OF_MONTH";
+ case DatetimeComponent::ComponentType::HOUR:
+ return "HOUR";
+ case DatetimeComponent::ComponentType::MINUTE:
+ return "MINUTE";
+ case DatetimeComponent::ComponentType::SECOND:
+ return "SECOND";
+ case DatetimeComponent::ComponentType::MERIDIEM:
+ return "MERIDIEM";
+ case DatetimeComponent::ComponentType::ZONE_OFFSET:
+ return "ZONE_OFFSET";
+ case DatetimeComponent::ComponentType::DST_OFFSET:
+ return "DST_OFFSET";
+ default:
+ return "";
+ }
+}
+
+std::string RelativeQualifierToString(
+ const DatetimeComponent::RelativeQualifier& relative_qualifier) {
+ switch (relative_qualifier) {
+ case DatetimeComponent::RelativeQualifier::UNSPECIFIED:
+ return "UNSPECIFIED";
+ case DatetimeComponent::RelativeQualifier::NEXT:
+ return "NEXT";
+ case DatetimeComponent::RelativeQualifier::THIS:
+ return "THIS";
+ case DatetimeComponent::RelativeQualifier::LAST:
+ return "LAST";
+ case DatetimeComponent::RelativeQualifier::NOW:
+ return "NOW";
+ case DatetimeComponent::RelativeQualifier::TOMORROW:
+ return "TOMORROW";
+ case DatetimeComponent::RelativeQualifier::YESTERDAY:
+ return "YESTERDAY";
+ case DatetimeComponent::RelativeQualifier::PAST:
+ return "PAST";
+ case DatetimeComponent::RelativeQualifier::FUTURE:
+ return "FUTURE";
+ default:
+ return "";
+ }
+}
+
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
const DatetimeParseResultSpan& value) {
stream << "DatetimeParseResultSpan({" << value.span.first << ", "
@@ -63,7 +121,16 @@
for (const DatetimeParseResult& data : value.data) {
stream << "{/*time_ms_utc=*/ " << data.time_ms_utc << " /* "
<< FormatMillis(data.time_ms_utc) << " */, /*granularity=*/ "
- << data.granularity << "}, ";
+ << data.granularity << ", /*datetime_components=*/ ";
+ for (const DatetimeComponent& datetime_comp : data.datetime_components) {
+ stream << "{/*component_type=*/ "
+ << ComponentTypeToString(datetime_comp.component_type)
+ << " /*relative_qualifier=*/ "
+ << RelativeQualifierToString(datetime_comp.relative_qualifier)
+ << " /*value=*/ " << datetime_comp.value << " /*relative_count=*/ "
+ << datetime_comp.relative_count << "}, ";
+ }
+ stream << "}, ";
}
stream << "})";
return stream;
diff --git a/native/annotator/types.h b/native/annotator/types.h
index ac24e24..9b94c10 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -352,7 +352,7 @@
// Entity data information.
std::string serialized_entity_data;
- const EntityData* entity_data() {
+ const EntityData* entity_data() const {
return LoadAndVerifyFlatbuffer<EntityData>(serialized_entity_data.data(),
serialized_entity_data.size());
}
diff --git a/native/annotator/zlib-utils.cc b/native/annotator/zlib-utils.cc
index ec2392b..c3c2cf1 100644
--- a/native/annotator/zlib-utils.cc
+++ b/native/annotator/zlib-utils.cc
@@ -125,6 +125,15 @@
extractor->compressed_pattern.reset(nullptr);
}
}
+
+ if (model->resources != nullptr) {
+ DecompressResources(model->resources.get());
+ }
+
+ if (model->intent_options != nullptr) {
+ DecompressIntentModel(model->intent_options.get());
+ }
+
return true;
}
diff --git a/native/annotator/zlib-utils_test.cc b/native/annotator/zlib-utils_test.cc
index 7a8d775..363c155 100644
--- a/native/annotator/zlib-utils_test.cc
+++ b/native/annotator/zlib-utils_test.cc
@@ -43,6 +43,37 @@
model.datetime_model->extractors.back()->pattern =
"an example datetime extractor";
+ model.intent_options.reset(new IntentFactoryModelT);
+ model.intent_options->generator.emplace_back(
+ new IntentFactoryModel_::IntentGeneratorT);
+ const std::string intent_generator1 = "lua generator 1";
+ model.intent_options->generator.back()->lua_template_generator =
+ std::vector<uint8_t>(intent_generator1.begin(), intent_generator1.end());
+ model.intent_options->generator.emplace_back(
+ new IntentFactoryModel_::IntentGeneratorT);
+ const std::string intent_generator2 = "lua generator 2";
+ model.intent_options->generator.back()->lua_template_generator =
+ std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end());
+
+ // NOTE: The resource strings contain some repetition, so that the compressed
+ // version is smaller than the uncompressed one. Because the compression code
+ // looks at that as well.
+ model.resources.reset(new ResourcePoolT);
+ model.resources->resource_entry.emplace_back(new ResourceEntryT);
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr1.1";
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr1.2";
+ model.resources->resource_entry.emplace_back(new ResourceEntryT);
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr2.1";
+ model.resources->resource_entry.back()->resource.emplace_back(new ResourceT);
+ model.resources->resource_entry.back()->resource.back()->content =
+ "rrrrrrrrrrrrr2.2";
+
// Compress the model.
EXPECT_TRUE(CompressModel(&model));
@@ -51,6 +82,14 @@
EXPECT_TRUE(model.regex_model->patterns[1]->pattern.empty());
EXPECT_TRUE(model.datetime_model->patterns[0]->regexes[0]->pattern.empty());
EXPECT_TRUE(model.datetime_model->extractors[0]->pattern.empty());
+ EXPECT_TRUE(
+ model.intent_options->generator[0]->lua_template_generator.empty());
+ EXPECT_TRUE(
+ model.intent_options->generator[1]->lua_template_generator.empty());
+ EXPECT_TRUE(model.resources->resource_entry[0]->resource[0]->content.empty());
+ EXPECT_TRUE(model.resources->resource_entry[0]->resource[1]->content.empty());
+ EXPECT_TRUE(model.resources->resource_entry[1]->resource[0]->content.empty());
+ EXPECT_TRUE(model.resources->resource_entry[1]->resource[1]->content.empty());
// Pack and load the model.
flatbuffers::FlatBufferBuilder builder;
@@ -94,6 +133,20 @@
"an example datetime pattern");
EXPECT_EQ(model.datetime_model->extractors[0]->pattern,
"an example datetime extractor");
+ EXPECT_EQ(
+ model.intent_options->generator[0]->lua_template_generator,
+ std::vector<uint8_t>(intent_generator1.begin(), intent_generator1.end()));
+ EXPECT_EQ(
+ model.intent_options->generator[1]->lua_template_generator,
+ std::vector<uint8_t>(intent_generator2.begin(), intent_generator2.end()));
+ EXPECT_EQ(model.resources->resource_entry[0]->resource[0]->content,
+ "rrrrrrrrrrrrr1.1");
+ EXPECT_EQ(model.resources->resource_entry[0]->resource[1]->content,
+ "rrrrrrrrrrrrr1.2");
+ EXPECT_EQ(model.resources->resource_entry[1]->resource[0]->content,
+ "rrrrrrrrrrrrr2.1");
+ EXPECT_EQ(model.resources->resource_entry[1]->resource[1]->content,
+ "rrrrrrrrrrrrr2.2");
}
} // namespace libtextclassifier3
diff --git a/native/lang_id/lang-id-wrapper.cc b/native/lang_id/lang-id-wrapper.cc
new file mode 100644
index 0000000..c2ab25c
--- /dev/null
+++ b/native/lang_id/lang-id-wrapper.cc
@@ -0,0 +1,96 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "lang_id/lang-id-wrapper.h"
+
+#include <fcntl.h>
+
+#include "lang_id/fb_model/lang-id-from-fb.h"
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+
+namespace langid {
+
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromPath(
+ const std::string& langid_model_path) {
+ std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
+ libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(langid_model_path);
+ return langid_model;
+}
+
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromDescriptor(
+ const int langid_fd) {
+ std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model =
+ libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor(
+ langid_fd);
+ return langid_model;
+}
+
+std::vector<std::pair<std::string, float>> GetPredictions(
+ const libtextclassifier3::mobile::lang_id::LangId* model, const std::string& text) {
+ std::vector<std::pair<std::string, float>> prediction_results;
+ if (model == nullptr) {
+ return prediction_results;
+ }
+
+ const float noise_threshold =
+ model->GetFloatProperty("text_classifier_langid_noise_threshold", -1.0f);
+
+ // Speed up the things by specifying the max results we want. For example, if
+ // the noise threshold is 0.1, we don't need more than 10 results.
+ const int max_results =
+ noise_threshold < 0.01
+ ? -1 // -1 means FindLanguages returns all predictions
+ : static_cast<int>(1 / noise_threshold) + 1;
+
+ libtextclassifier3::mobile::lang_id::LangIdResult langid_result;
+ model->FindLanguages(text, &langid_result, max_results);
+ for (int i = 0; i < langid_result.predictions.size(); i++) {
+ const auto& prediction = langid_result.predictions[i];
+ if (prediction.second >= noise_threshold && prediction.first != "und") {
+ prediction_results.push_back({prediction.first, prediction.second});
+ }
+ }
+ return prediction_results;
+}
+
+std::string GetLanguageTags(const libtextclassifier3::mobile::lang_id::LangId* model,
+ const std::string& text) {
+ const std::vector<std::pair<std::string, float>>& predictions =
+ GetPredictions(model, text);
+ const float threshold =
+ model->GetFloatProperty("text_classifier_langid_threshold", -1.0f);
+ std::string detected_language_tags = "";
+ bool first_accepted_language = true;
+ for (int i = 0; i < predictions.size(); i++) {
+ const auto& prediction = predictions[i];
+ if (threshold >= 0.f && prediction.second < threshold) {
+ continue;
+ }
+ if (first_accepted_language) {
+ first_accepted_language = false;
+ } else {
+ detected_language_tags += ",";
+ }
+ detected_language_tags += prediction.first;
+ }
+ return detected_language_tags;
+}
+
+} // namespace langid
+
+} // namespace libtextclassifier3
diff --git a/native/lang_id/lang-id-wrapper.h b/native/lang_id/lang-id-wrapper.h
new file mode 100644
index 0000000..4c0104b
--- /dev/null
+++ b/native/lang_id/lang-id-wrapper.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_WRAPPER_H_
+#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_WRAPPER_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "lang_id/lang-id.h"
+
+namespace libtextclassifier3 {
+
+namespace langid {
+
+// Loads the LangId model from a given path.
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromPath(
+ const std::string& path);
+
+// Loads the LangId model from a file descriptor.
+std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> LoadFromDescriptor(
+ const int fd);
+
+// Returns the LangId predictions (locale, confidence) from the given LangId
+// model. The maximum number of predictions returned will be computed internally
+// relatively to the noise threshold.
+std::vector<std::pair<std::string, float>> GetPredictions(
+ const libtextclassifier3::mobile::lang_id::LangId* model, const std::string& text);
+
+// Returns the language tags string from the given LangId model. The language
+// tags will be filtered internally by the LangId threshold.
+std::string GetLanguageTags(const libtextclassifier3::mobile::lang_id::LangId* model,
+ const std::string& text);
+
+} // namespace langid
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_WRAPPER_H_
diff --git a/native/lang_id/lang-id.cc b/native/lang_id/lang-id.cc
index ef82456..92359a9 100644
--- a/native/lang_id/lang-id.cc
+++ b/native/lang_id/lang-id.cc
@@ -141,6 +141,12 @@
LightSentence sentence;
tokenizer_.Tokenize(text, &sentence);
+ // Test input size here, after pre-processing removed irrelevant chars.
+ if (IsTooShort(sentence)) {
+ result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1);
+ return;
+ }
+
// Extract features from the tokenized text.
std::vector<FeatureVector> features =
lang_id_brain_interface_.GetFeaturesNoCaching(&sentence);
@@ -205,6 +211,8 @@
bool Setup(TaskContext *context) {
tokenizer_.Setup(context);
if (!lang_id_brain_interface_.SetupForProcessing(context)) return false;
+
+ min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0);
default_threshold_ =
context->Get("reliability_thresh", kDefaultConfidenceThreshold);
@@ -243,6 +251,16 @@
}
}
+ bool IsTooShort(const LightSentence &sentence) const {
+ int text_size = 0;
+ for (const std::string &token : sentence) {
+ // Each token has the form ^...$: we subtract 2 because we want to count
+ // only the real text, not the chars added by us.
+ text_size += token.size() - 2;
+ }
+ return text_size < min_text_size_in_bytes_;
+ }
+
std::unique_ptr<ModelProvider> model_provider_;
TokenizerForLangId tokenizer_;
@@ -256,6 +274,11 @@
// True if this object is ready to perform language predictions.
bool valid_ = false;
+ // The model returns LangId::kUnknownLanguageCode for input text that has
+ // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespaces,
+ // digits, and punctuation).
+ int min_text_size_in_bytes_ = 0;
+
// Only predictions with a probability (confidence) above this threshold are
// reported. Otherwise, we report LangId::kUnknownLanguageCode.
float default_threshold_ = kDefaultConfidenceThreshold;
diff --git a/native/lang_id/lang-id_jni.cc b/native/lang_id/lang-id_jni.cc
index 02b388f..30753dc 100644
--- a/native/lang_id/lang-id_jni.cc
+++ b/native/lang_id/lang-id_jni.cc
@@ -17,15 +17,19 @@
#include "lang_id/lang-id_jni.h"
#include <jni.h>
+
#include <type_traits>
#include <vector>
+#include "lang_id/lang-id-wrapper.h"
#include "utils/base/logging.h"
-#include "utils/java/scoped_local_ref.h"
+#include "utils/java/jni-helper.h"
#include "lang_id/fb_model/lang-id-from-fb.h"
#include "lang_id/lang-id.h"
+using libtextclassifier3::JniHelper;
using libtextclassifier3::ScopedLocalRef;
+using libtextclassifier3::StatusOr;
using libtextclassifier3::ToStlString;
using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile;
using libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFileDescriptor;
@@ -33,36 +37,33 @@
using libtextclassifier3::mobile::lang_id::LangIdResult;
namespace {
-jobjectArray LangIdResultToJObjectArray(JNIEnv* env,
- const LangIdResult& lang_id_result,
- const float significant_threshold) {
- const ScopedLocalRef<jclass> result_class(
- env->FindClass(TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR
- "$LanguageResult"),
- env);
- if (!result_class) {
- TC3_LOG(ERROR) << "Couldn't find LanguageResult class.";
- return nullptr;
- }
- std::vector<std::pair<std::string, float>> predictions;
- std::copy_if(lang_id_result.predictions.begin(),
- lang_id_result.predictions.end(),
- std::back_inserter(predictions),
- [significant_threshold](std::pair<std::string, float> pair) {
- return pair.second >= significant_threshold;
- });
+StatusOr<ScopedLocalRef<jobjectArray>> LangIdResultToJObjectArray(
+ JNIEnv* env,
+ const std::vector<std::pair<std::string, float>>& lang_id_predictions) {
+ TC3_ASSIGN_OR_RETURN(
+ const ScopedLocalRef<jclass> result_class,
+ JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_LANG_ID_CLASS_NAME_STR "$LanguageResult"));
- const jmethodID result_class_constructor =
- env->GetMethodID(result_class.get(), "<init>", "(Ljava/lang/String;F)V");
- const jobjectArray results =
- env->NewObjectArray(predictions.size(), result_class.get(), nullptr);
- for (int i = 0; i < predictions.size(); i++) {
- ScopedLocalRef<jobject> result(
- env->NewObject(result_class.get(), result_class_constructor,
- env->NewStringUTF(predictions[i].first.c_str()),
- static_cast<jfloat>(predictions[i].second)));
- env->SetObjectArrayElement(results, i, result.get());
+ TC3_ASSIGN_OR_RETURN(const jmethodID result_class_constructor,
+ JniHelper::GetMethodID(env, result_class.get(), "<init>",
+ "(Ljava/lang/String;F)V"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, lang_id_predictions.size(),
+ result_class.get(), nullptr));
+ for (int i = 0; i < lang_id_predictions.size(); i++) {
+ TC3_ASSIGN_OR_RETURN(
+ const ScopedLocalRef<jstring> predicted_language,
+ JniHelper::NewStringUTF(env, lang_id_predictions[i].first.c_str()));
+ TC3_ASSIGN_OR_RETURN(
+ const ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(
+ env, result_class.get(), result_class_constructor,
+ predicted_language.get(),
+ static_cast<jfloat>(lang_id_predictions[i].second)));
+ env->SetObjectArrayElement(results.get(), i, result.get());
}
return results;
}
@@ -83,7 +84,7 @@
TC3_JNI_METHOD(jlong, TC3_LANG_ID_CLASS_NAME, nativeNewFromPath)
(JNIEnv* env, jobject thiz, jstring path) {
- const std::string path_str = ToStlString(env, path);
+ TC3_ASSIGN_OR_RETURN_0(const std::string path_str, ToStlString(env, path));
std::unique_ptr<LangId> lang_id = GetLangIdFromFlatbufferFile(path_str);
if (!lang_id->is_valid()) {
return reinterpret_cast<jlong>(nullptr);
@@ -98,17 +99,15 @@
return nullptr;
}
- const std::string text_str = ToStlString(env, text);
- const float noise_threshold = GetNoiseThreshold(*model);
- // Speed up the things by specifying the max results we want. For example, if
- // the noise threshold is 0.1, we don't need more than 10 results.
- const int max_results =
- noise_threshold < 0.01
- ? -1 // -1 means FindLanguages returns all predictions
- : static_cast<int>(1 / noise_threshold) + 1;
- LangIdResult result;
- model->FindLanguages(text_str, &result, max_results);
- return LangIdResultToJObjectArray(env, result, noise_threshold);
+ TC3_ASSIGN_OR_RETURN_NULL(const std::string text_str, ToStlString(env, text));
+
+ const std::vector<std::pair<std::string, float>>& prediction_results =
+ libtextclassifier3::langid::GetPredictions(model, text_str);
+
+ TC3_ASSIGN_OR_RETURN_NULL(
+ ScopedLocalRef<jobjectArray> results,
+ LangIdResultToJObjectArray(env, prediction_results));
+ return results.release();
}
TC3_JNI_METHOD(void, TC3_LANG_ID_CLASS_NAME, nativeClose)
@@ -156,3 +155,12 @@
LangId* model = reinterpret_cast<LangId*>(ptr);
return GetNoiseThreshold(*model);
}
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetMinTextSizeInBytes)
+(JNIEnv* env, jobject thizz, jlong ptr) {
+ if (!ptr) {
+ return 0;
+ }
+ LangId* model = reinterpret_cast<LangId*>(ptr);
+ return model->GetFloatProperty("min_text_size_in_bytes", 0);
+}
diff --git a/native/lang_id/lang-id_jni.h b/native/lang_id/lang-id_jni.h
index b765ad4..219349c 100644
--- a/native/lang_id/lang-id_jni.h
+++ b/native/lang_id/lang-id_jni.h
@@ -57,6 +57,9 @@
TC3_JNI_METHOD(jfloat, TC3_LANG_ID_CLASS_NAME, nativeGetLangIdNoiseThreshold)
(JNIEnv* env, jobject thizz, jlong ptr);
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetMinTextSizeInBytes)
+(JNIEnv* env, jobject thizz, jlong ptr);
+
#ifdef __cplusplus
}
#endif
diff --git a/native/models/actions_suggestions.en.model b/native/models/actions_suggestions.en.model
index 90d66ba..480a5ff 100644
--- a/native/models/actions_suggestions.en.model
+++ b/native/models/actions_suggestions.en.model
Binary files differ
diff --git a/native/models/actions_suggestions.universal.model b/native/models/actions_suggestions.universal.model
index 74f9ee5..a285ab0 100644
--- a/native/models/actions_suggestions.universal.model
+++ b/native/models/actions_suggestions.universal.model
Binary files differ
diff --git a/native/models/lang_id.model b/native/models/lang_id.model
index 92f0103..e94dada 100644
--- a/native/models/lang_id.model
+++ b/native/models/lang_id.model
Binary files differ
diff --git a/native/utils/base/macros.h b/native/utils/base/macros.h
index 3517225..0b99fe1 100644
--- a/native/utils/base/macros.h
+++ b/native/utils/base/macros.h
@@ -85,6 +85,29 @@
} while (0)
#endif
+#ifdef __has_builtin
+#define TC3_HAS_BUILTIN(x) __has_builtin(x)
+#else
+#define TC3_HAS_BUILTIN(x) 0
+#endif
+
+// Compilers can be told that a certain branch is not likely to be taken
+// (for instance, a CHECK failure), and use that information in static
+// analysis. Giving it this information can help it optimize for the
+// common case in the absence of better information (ie.
+// -fprofile-arcs).
+//
+// We need to disable this for GPU builds, though, since nvcc8 and older
+// don't recognize `__builtin_expect` as a builtin, and fail compilation.
+#if (!defined(__NVCC__)) && (TC3_HAS_BUILTIN(__builtin_expect) || \
+ (defined(__GNUC__) && __GNUC__ >= 3))
+#define TC3_PREDICT_FALSE(x) (__builtin_expect(x, 0))
+#define TC3_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1))
+#else
+#define TC3_PREDICT_FALSE(x) (x)
+#define TC3_PREDICT_TRUE(x) (x)
+#endif
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_BASE_MACROS_H_
diff --git a/native/utils/base/status.cc b/native/utils/base/status.cc
new file mode 100644
index 0000000..9e758e6
--- /dev/null
+++ b/native/utils/base/status.cc
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/base/status.h"
+
+namespace libtextclassifier3 {
+
+const Status& Status::OK = *new Status(StatusCode::OK, "");
+const Status& Status::UNKNOWN = *new Status(StatusCode::UNKNOWN, "");
+
+Status::Status() : code_(StatusCode::OK) {}
+Status::Status(StatusCode error, const std::string& message)
+ : code_(error), message_(message) {}
+
+Status& Status::operator=(const Status& other) {
+ code_ = other.code_;
+ message_ = other.message_;
+ return *this;
+}
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Status& status) {
+ stream << status.error_code();
+ return stream;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/base/status.h b/native/utils/base/status.h
new file mode 100644
index 0000000..e2220db
--- /dev/null
+++ b/native/utils/base/status.h
@@ -0,0 +1,76 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_STATUS_H_
+#define LIBTEXTCLASSIFIER_UTILS_BASE_STATUS_H_
+
+#include <string>
+
+#include "utils/base/logging.h"
+
+namespace libtextclassifier3 {
+
+enum class StatusCode {
+ // Not an error; returned on success
+ OK = 0,
+
+ // Unknown error.
+ UNKNOWN = 2,
+
+ // Client specified an invalid argument.
+ INVALID_ARGUMENT = 3,
+};
+
+// A Status is a combination of an error code and a string message (for non-OK
+// error codes).
+class Status {
+ public:
+ // Creates an OK status
+ Status();
+
+ // Make a Status from the specified error and message.
+ Status(StatusCode error, const std::string& error_message);
+
+ Status& operator=(const Status& other);
+
+ // Some pre-defined Status objects
+ static const Status& OK;
+ static const Status& UNKNOWN;
+
+ // Accessors
+ bool ok() const { return code_ == StatusCode::OK; }
+ int error_code() const { return static_cast<int>(code_); }
+
+ StatusCode CanonicalCode() const { return code_; }
+
+ const std::string& error_message() const { return message_; }
+
+ bool operator==(const Status& x) const;
+ bool operator!=(const Status& x) const;
+
+ std::string ToString() const;
+
+ private:
+ StatusCode code_;
+ std::string message_;
+};
+
+logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
+ const Status& status);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BASE_STATUS_H_
diff --git a/native/utils/base/status_test.cc b/native/utils/base/status_test.cc
new file mode 100644
index 0000000..9e3b5c6
--- /dev/null
+++ b/native/utils/base/status_test.cc
@@ -0,0 +1,73 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/base/status.h"
+
+#include "utils/base/logging.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(StatusTest, PrintsAbortedStatus) {
+ logging::LoggingStringStream stream;
+ stream << Status::UNKNOWN;
+ EXPECT_EQ(Status::UNKNOWN.error_code(), 2);
+ EXPECT_EQ(Status::UNKNOWN.CanonicalCode(), StatusCode::UNKNOWN);
+ EXPECT_EQ(Status::UNKNOWN.error_message(), "");
+ EXPECT_EQ(stream.message, "2");
+}
+
+TEST(StatusTest, PrintsOKStatus) {
+ logging::LoggingStringStream stream;
+ stream << Status::OK;
+ EXPECT_EQ(Status::OK.error_code(), 0);
+ EXPECT_EQ(Status::OK.CanonicalCode(), StatusCode::OK);
+ EXPECT_EQ(Status::OK.error_message(), "");
+ EXPECT_EQ(stream.message, "0");
+}
+
+TEST(StatusTest, UnknownStatusHasRightAttributes) {
+ EXPECT_EQ(Status::UNKNOWN.error_code(), 2);
+ EXPECT_EQ(Status::UNKNOWN.CanonicalCode(), StatusCode::UNKNOWN);
+ EXPECT_EQ(Status::UNKNOWN.error_message(), "");
+}
+
+TEST(StatusTest, OkStatusHasRightAttributes) {
+ EXPECT_EQ(Status::OK.error_code(), 0);
+ EXPECT_EQ(Status::OK.CanonicalCode(), StatusCode::OK);
+ EXPECT_EQ(Status::OK.error_message(), "");
+}
+
+TEST(StatusTest, CustomStatusHasRightAttributes) {
+ Status status(StatusCode::INVALID_ARGUMENT, "You can't put this here!");
+ EXPECT_EQ(status.error_code(), 3);
+ EXPECT_EQ(status.CanonicalCode(), StatusCode::INVALID_ARGUMENT);
+ EXPECT_EQ(status.error_message(), "You can't put this here!");
+}
+
+TEST(StatusTest, AssignmentPreservesMembers) {
+ Status status(StatusCode::INVALID_ARGUMENT, "You can't put this here!");
+
+ Status status2 = status;
+
+ EXPECT_EQ(status2.error_code(), 3);
+ EXPECT_EQ(status2.CanonicalCode(), StatusCode::INVALID_ARGUMENT);
+ EXPECT_EQ(status2.error_message(), "You can't put this here!");
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/base/statusor.h b/native/utils/base/statusor.h
new file mode 100644
index 0000000..b2c719a
--- /dev/null
+++ b/native/utils/base/statusor.h
@@ -0,0 +1,231 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_BASE_STATUSOR_H_
+#define LIBTEXTCLASSIFIER_UTILS_BASE_STATUSOR_H_
+
+#include "utils/base/logging.h"
+#include "utils/base/macros.h"
+#include "utils/base/status.h"
+
+namespace libtextclassifier3 {
+
+// A StatusOr holds a Status (in the case of an error), or a value T.
+template <typename T>
+class StatusOr {
+ public:
+ // Has status UNKNOWN.
+ inline StatusOr();
+
+ // Builds from a non-OK status. Crashes if an OK status is specified.
+ inline StatusOr(const Status& status); // NOLINT
+ inline StatusOr(const Status& status, T&& value); // NOLINT
+
+ // Builds from the specified value.
+ inline StatusOr(const T& value); // NOLINT
+ inline StatusOr(T&& value); // NOLINT
+
+ // Copy constructor.
+ inline StatusOr(const StatusOr& other);
+
+ // Move constructor.
+ inline StatusOr(StatusOr&& other);
+
+ // Conversion copy constructor, T must be copy constructible from U.
+ template <typename U>
+ inline StatusOr(const StatusOr<U>& other); // NOLINT
+
+ // Assignment operator.
+ inline StatusOr& operator=(const StatusOr& other);
+
+ inline StatusOr& operator=(StatusOr&& other);
+
+ // Conversion assignment operator, T must be assignable from U
+ template <typename U>
+ inline StatusOr& operator=(const StatusOr<U>& other);
+
+ // Accessors.
+ inline const Status& status() const { return status_; }
+
+ // Shorthand for status().ok().
+ inline bool ok() const { return status_.ok(); }
+
+ // Returns value or crashes if ok() is false.
+ inline const T& ValueOrDie() const& {
+ if (!ok()) {
+ TC3_LOG(FATAL) << "Attempting to fetch value of non-OK StatusOr: "
+ << status();
+ exit(1);
+ }
+ return value_;
+ }
+ inline T& ValueOrDie() & {
+ if (!ok()) {
+ TC3_LOG(FATAL) << "Attempting to fetch value of non-OK StatusOr: "
+ << status();
+ exit(1);
+ }
+ return value_;
+ }
+ inline const T&& ValueOrDie() const&& {
+ if (!ok()) {
+ TC3_LOG(FATAL) << "Attempting to fetch value of non-OK StatusOr: "
+ << status();
+ exit(1);
+ }
+ return value_;
+ }
+ inline T&& ValueOrDie() && {
+ if (!ok()) {
+ TC3_LOG(FATAL) << "Attempting to fetch value of non-OK StatusOr: "
+ << status();
+ exit(1);
+ }
+ return value_;
+ }
+
+ template <typename U>
+ friend class StatusOr;
+
+ private:
+ Status status_;
+ T value_;
+};
+
+// Implementation.
+
+template <typename T>
+inline StatusOr<T>::StatusOr() : status_(StatusCode::UNKNOWN, "") {}
+
+template <typename T>
+inline StatusOr<T>::StatusOr(const Status& status) : status_(status) {
+ if (status.ok()) {
+ TC3_LOG(FATAL) << "OkStatus() is not a valid argument to StatusOr";
+ exit(1);
+ }
+}
+
+template <typename T>
+inline StatusOr<T>::StatusOr(const Status& status, T&& value)
+ : status_(status), value_(std::move(value)) {
+ if (status.ok()) {
+ TC3_LOG(FATAL) << "OkStatus() is not a valid argument to StatusOr";
+ exit(1);
+ }
+}
+
+template <typename T>
+inline StatusOr<T>::StatusOr(const T& value) : value_(value) {}
+
+template <typename T>
+inline StatusOr<T>::StatusOr(T&& value) : value_(std::move(value)) {}
+
+template <typename T>
+inline StatusOr<T>::StatusOr(const StatusOr& other)
+ : status_(other.status_), value_(other.value_) {}
+
+template <typename T>
+inline StatusOr<T>::StatusOr(StatusOr&& other)
+ : status_(other.status_), value_(std::move(other.value_)) {}
+
+template <typename T>
+template <typename U>
+inline StatusOr<T>::StatusOr(const StatusOr<U>& other)
+ : status_(other.status_), value_(other.value_) {}
+
+template <typename T>
+inline StatusOr<T>& StatusOr<T>::operator=(const StatusOr& other) {
+ status_ = other.status_;
+ if (status_.ok()) {
+ value_ = other.value_;
+ }
+ return *this;
+}
+
+template <typename T>
+inline StatusOr<T>& StatusOr<T>::operator=(StatusOr&& other) {
+ status_ = other.status_;
+ if (status_.ok()) {
+ value_ = std::move(other.value_);
+ }
+ return *this;
+}
+
+template <typename T>
+template <typename U>
+inline StatusOr<T>& StatusOr<T>::operator=(const StatusOr<U>& other) {
+ status_ = other.status_;
+ if (status_.ok()) {
+ value_ = other.value_;
+ }
+ return *this;
+}
+
+} // namespace libtextclassifier3
+
+#define TC3_STATUS_MACROS_CONCAT_NAME(x, y) TC3_STATUS_MACROS_CONCAT_IMPL(x, y)
+#define TC3_STATUS_MACROS_CONCAT_IMPL(x, y) x##y
+
+// Macros that help consume StatusOr<...> return values and propagate errors.
+#define TC3_ASSIGN_OR_RETURN(lhs, rexpr) \
+ TC3_ASSIGN_OR_RETURN_IMPL( \
+ TC3_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
+ rexpr)
+
+#define TC3_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \
+ auto statusor = (rexpr); \
+ if (TC3_PREDICT_FALSE(!statusor.ok())) { \
+ return statusor.status(); \
+ } \
+ lhs = std::move(statusor.ValueOrDie())
+
+#define TC3_ASSIGN_OR_RETURN_NULL(lhs, rexpr) \
+ TC3_ASSIGN_OR_RETURN_NULL_IMPL( \
+ TC3_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
+ rexpr)
+
+#define TC3_ASSIGN_OR_RETURN_NULL_IMPL(statusor, lhs, rexpr) \
+ auto statusor = (rexpr); \
+ if (TC3_PREDICT_FALSE(!statusor.ok())) { \
+ return nullptr; \
+ } \
+ lhs = std::move(statusor.ValueOrDie())
+
+#define TC3_ASSIGN_OR_RETURN_FALSE(lhs, rexpr) \
+ TC3_ASSIGN_OR_RETURN_FALSE_IMPL( \
+ TC3_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
+ rexpr)
+
+#define TC3_ASSIGN_OR_RETURN_FALSE_IMPL(statusor, lhs, rexpr) \
+ auto statusor = (rexpr); \
+ if (TC3_PREDICT_FALSE(!statusor.ok())) { \
+ return false; \
+ } \
+ lhs = std::move(statusor.ValueOrDie())
+
+#define TC3_ASSIGN_OR_RETURN_0(lhs, rexpr) \
+ TC3_ASSIGN_OR_RETURN_0_IMPL( \
+ TC3_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, \
+ rexpr)
+
+#define TC3_ASSIGN_OR_RETURN_0_IMPL(statusor, lhs, rexpr) \
+ auto statusor = (rexpr); \
+ if (TC3_PREDICT_FALSE(!statusor.ok())) { \
+ return 0; \
+ } \
+ lhs = std::move(statusor.ValueOrDie())
+
+#endif // LIBTEXTCLASSIFIER_UTILS_BASE_STATUSOR_H_
diff --git a/native/utils/base/statusor_test.cc b/native/utils/base/statusor_test.cc
new file mode 100644
index 0000000..9e22a3b
--- /dev/null
+++ b/native/utils/base/statusor_test.cc
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/base/statusor.h"
+
+#include "utils/base/logging.h"
+#include "utils/base/status.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(StatusOrTest, DoesntDieWhenOK) {
+ StatusOr<std::string> status_or_string = std::string("Hello World");
+ EXPECT_TRUE(status_or_string.ok());
+ EXPECT_EQ(status_or_string.ValueOrDie(), "Hello World");
+}
+
+TEST(StatusOrTest, DiesWhenNotOK) {
+ StatusOr<std::string> status_or_string = {Status::UNKNOWN};
+ EXPECT_FALSE(status_or_string.ok());
+ EXPECT_DEATH(status_or_string.ValueOrDie(),
+ "Attempting to fetch value of non-OK StatusOr: 2");
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/base/unaligned_access.h b/native/utils/base/unaligned_access.h
index d6907db..2158197 100644
--- a/native/utils/base/unaligned_access.h
+++ b/native/utils/base/unaligned_access.h
@@ -93,10 +93,8 @@
} // namespace libtextclassifier3
-#define TC3_INTERNAL_UNALIGNED_LOAD16(_p) \
- (::libtextclassifier3::UnalignedLoad16(_p))
-#define TC3_INTERNAL_UNALIGNED_LOAD32(_p) \
- (::libtextclassifier3::UnalignedLoad32(_p))
+#define TC3_UNALIGNED_LOAD16(_p) (::libtextclassifier3::UnalignedLoad16(_p))
+#define TC3_UNALIGNED_LOAD32(_p) (::libtextclassifier3::UnalignedLoad32(_p))
#define TC3_UNALIGNED_LOAD64(_p) \
(::libtextclassifier3::UnalignedLoad64(_p))
diff --git a/native/utils/calendar/calendar-common.h b/native/utils/calendar/calendar-common.h
index f47b367..1f5b128 100644
--- a/native/utils/calendar/calendar-common.h
+++ b/native/utils/calendar/calendar-common.h
@@ -41,6 +41,7 @@
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& reference_locale,
+ bool prefer_future_for_unspecified_date,
TCalendar* calendar,
DatetimeGranularity* granularity) const;
@@ -49,7 +50,7 @@
private:
// Adjusts the calendar's time instant according to a relative date reference
// in the parsed data.
- bool ApplyRelationField(const DatetimeParsedData& parse_data,
+ bool ApplyRelationField(const DatetimeComponent& relative_date_time_component,
TCalendar* calendar) const;
// Round the time instant's precision down to the given granularity.
@@ -68,11 +69,30 @@
bool allow_today, TCalendar* calendar) const;
};
+inline bool HasOnlyTimeComponents(const DatetimeParsedData& parse_data) {
+ std::vector<DatetimeComponent> components;
+ parse_data.GetDatetimeComponents(&components);
+
+ for (const DatetimeComponent& component : components) {
+ if (!(component.component_type == DatetimeComponent::ComponentType::HOUR ||
+ component.component_type ==
+ DatetimeComponent::ComponentType::MINUTE ||
+ component.component_type ==
+ DatetimeComponent::ComponentType::SECOND ||
+ component.component_type ==
+ DatetimeComponent::ComponentType::MERIDIEM)) {
+ return false;
+ }
+ }
+ return true;
+}
+
template <class TCalendar>
bool CalendarLibTempl<TCalendar>::InterpretParseData(
const DatetimeParsedData& parse_data, int64 reference_time_ms_utc,
const std::string& reference_timezone, const std::string& reference_locale,
- TCalendar* calendar, DatetimeGranularity* granularity) const {
+ bool prefer_future_for_unspecified_date, TCalendar* calendar,
+ DatetimeGranularity* granularity) const {
TC3_CALENDAR_CHECK(calendar->Initialize(reference_timezone, reference_locale,
reference_time_ms_utc))
@@ -98,8 +118,9 @@
std::vector<DatetimeComponent> relative_components;
parse_data.GetRelativeDatetimeComponents(&relative_components);
if (!relative_components.empty()) {
- TC3_CALENDAR_CHECK(ApplyRelationField(parse_data, calendar));
+ // Currently only one relative date time component is possible.
const DatetimeComponent& relative_component = relative_components.back();
+ TC3_CALENDAR_CHECK(ApplyRelationField(relative_component, calendar));
should_round_to_granularity = relative_component.ShouldRoundToGranularity();
} else {
// By default, the parsed time is interpreted to be on the reference day.
@@ -162,22 +183,22 @@
if (should_round_to_granularity) {
TC3_CALENDAR_CHECK(RoundToGranularity(*granularity, calendar))
}
+
+ int64 calendar_millis;
+ TC3_CALENDAR_CHECK(calendar->GetTimeInMillis(&calendar_millis))
+ if (prefer_future_for_unspecified_date &&
+ calendar_millis < reference_time_ms_utc &&
+ HasOnlyTimeComponents(parse_data)) {
+ calendar->AddDayOfMonth(1);
+ }
+
return true;
}
template <class TCalendar>
bool CalendarLibTempl<TCalendar>::ApplyRelationField(
- const DatetimeParsedData& parse_data, TCalendar* calendar) const {
- std::vector<DatetimeComponent> relative_date_time_components;
- parse_data.GetRelativeDatetimeComponents(&relative_date_time_components);
- if (relative_date_time_components.empty()) {
- // There is no relative field set in the parsed data.
- return false;
- }
- // Current only one relative date time component is possible.
- DatetimeComponent relative_date_time_component =
- relative_date_time_components.back();
-
+ const DatetimeComponent& relative_date_time_component,
+ TCalendar* calendar) const {
switch (relative_date_time_component.relative_qualifier) {
case DatetimeComponent::RelativeQualifier::UNSPECIFIED:
TC3_LOG(ERROR) << "UNSPECIFIED RelationType.";
diff --git a/native/utils/calendar/calendar-javaicu.cc b/native/utils/calendar/calendar-javaicu.cc
index 59af9d4..048df04 100644
--- a/native/utils/calendar/calendar-javaicu.cc
+++ b/native/utils/calendar/calendar-javaicu.cc
@@ -17,7 +17,9 @@
#include "utils/calendar/calendar-javaicu.h"
#include "annotator/types.h"
-#include "utils/java/scoped_local_ref.h"
+#include "utils/base/statusor.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
namespace libtextclassifier3 {
namespace {
@@ -25,22 +27,26 @@
// Generic version of icu::Calendar::add with error checking.
bool CalendarAdd(JniCache* jni_cache, JNIEnv* jenv, jobject calendar,
jint field, jint value) {
- jenv->CallVoidMethod(calendar, jni_cache->calendar_add, field, value);
- return !jni_cache->ExceptionCheckAndClear();
+ return JniHelper::CallVoidMethod(jenv, calendar, jni_cache->calendar_add,
+ field, value)
+ .ok();
}
// Generic version of icu::Calendar::get with error checking.
bool CalendarGet(JniCache* jni_cache, JNIEnv* jenv, jobject calendar,
jint field, jint* value) {
- *value = jenv->CallIntMethod(calendar, jni_cache->calendar_get, field);
- return !jni_cache->ExceptionCheckAndClear();
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ *value,
+ JniHelper::CallIntMethod(jenv, calendar, jni_cache->calendar_get, field));
+ return true;
}
// Generic version of icu::Calendar::set with error checking.
bool CalendarSet(JniCache* jni_cache, JNIEnv* jenv, jobject calendar,
jint field, jint value) {
- jenv->CallVoidMethod(calendar, jni_cache->calendar_set, field, value);
- return !jni_cache->ExceptionCheckAndClear();
+ return JniHelper::CallVoidMethod(jenv, calendar, jni_cache->calendar_set,
+ field, value)
+ .ok();
}
// Extracts the first tag from a BCP47 tag (e.g. "en" for "en-US").
@@ -57,7 +63,8 @@
Calendar::Calendar(JniCache* jni_cache)
: jni_cache_(jni_cache),
- jenv_(jni_cache_ ? jni_cache->GetEnv() : nullptr) {}
+ jenv_(jni_cache_ ? jni_cache->GetEnv() : nullptr),
+ calendar_(nullptr, jenv_) {}
bool Calendar::Initialize(const std::string& time_zone,
const std::string& locale, int64 time_ms_utc) {
@@ -79,51 +86,62 @@
}
// Get the time zone.
- ScopedLocalRef<jstring> java_time_zone_str(
- jenv_->NewStringUTF(time_zone.c_str()));
- ScopedLocalRef<jobject> java_time_zone(jenv_->CallStaticObjectMethod(
- jni_cache_->timezone_class.get(), jni_cache_->timezone_get_timezone,
- java_time_zone_str.get()));
- if (jni_cache_->ExceptionCheckAndClear() || !java_time_zone) {
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> java_time_zone_str,
+ JniHelper::NewStringUTF(jenv_, time_zone.c_str()));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ ScopedLocalRef<jobject> java_time_zone,
+ JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->timezone_class.get(),
+ jni_cache_->timezone_get_timezone,
+ java_time_zone_str.get()));
+ if (java_time_zone == nullptr) {
TC3_LOG(ERROR) << "failed to get timezone";
return false;
}
// Get the locale.
- ScopedLocalRef<jobject> java_locale;
+ ScopedLocalRef<jobject> java_locale(nullptr, jenv_);
if (jni_cache_->locale_for_language_tag) {
// API level 21+, we can actually parse language tags.
- ScopedLocalRef<jstring> java_locale_str(
- jenv_->NewStringUTF(locale.c_str()));
- java_locale.reset(jenv_->CallStaticObjectMethod(
- jni_cache_->locale_class.get(), jni_cache_->locale_for_language_tag,
- java_locale_str.get()));
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> java_locale_str,
+ JniHelper::NewStringUTF(jenv_, locale.c_str()));
+
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ java_locale,
+ JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->locale_class.get(),
+ jni_cache_->locale_for_language_tag,
+ java_locale_str.get()));
} else {
// API level <21. We can't parse tags, so we just use the language.
- ScopedLocalRef<jstring> java_language_str(
- jenv_->NewStringUTF(GetFirstBcp47Tag(locale).c_str()));
- java_locale.reset(jenv_->NewObject(jni_cache_->locale_class.get(),
- jni_cache_->locale_init_string,
- java_language_str.get()));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ ScopedLocalRef<jstring> java_language_str,
+ JniHelper::NewStringUTF(jenv_, GetFirstBcp47Tag(locale).c_str()));
+
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ java_locale, JniHelper::NewObject(jenv_, jni_cache_->locale_class.get(),
+ jni_cache_->locale_init_string,
+ java_language_str.get()));
}
- if (jni_cache_->ExceptionCheckAndClear() || !java_locale) {
+ if (java_locale == nullptr) {
TC3_LOG(ERROR) << "failed to get locale";
return false;
}
// Get the calendar.
- calendar_.reset(jenv_->CallStaticObjectMethod(
- jni_cache_->calendar_class.get(), jni_cache_->calendar_get_instance,
- java_time_zone.get(), java_locale.get()));
- if (jni_cache_->ExceptionCheckAndClear() || !calendar_) {
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ calendar_, JniHelper::CallStaticObjectMethod(
+ jenv_, jni_cache_->calendar_class.get(),
+ jni_cache_->calendar_get_instance, java_time_zone.get(),
+ java_locale.get()));
+ if (calendar_ == nullptr) {
TC3_LOG(ERROR) << "failed to get calendar";
return false;
}
// Set the time.
- jenv_->CallVoidMethod(calendar_.get(),
- jni_cache_->calendar_set_time_in_millis, time_ms_utc);
- if (jni_cache_->ExceptionCheckAndClear()) {
+ if (!JniHelper::CallVoidMethod(jenv_, calendar_.get(),
+ jni_cache_->calendar_set_time_in_millis,
+ time_ms_utc)
+ .ok()) {
TC3_LOG(ERROR) << "failed to set time";
return false;
}
@@ -132,16 +150,23 @@
bool Calendar::GetFirstDayOfWeek(int* value) const {
if (!jni_cache_ || !jenv_ || !calendar_) return false;
- *value = jenv_->CallIntMethod(calendar_.get(),
- jni_cache_->calendar_get_first_day_of_week);
- return !jni_cache_->ExceptionCheckAndClear();
+
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ *value,
+ JniHelper::CallIntMethod(jenv_, calendar_.get(),
+ jni_cache_->calendar_get_first_day_of_week));
+ return true;
}
bool Calendar::GetTimeInMillis(int64* value) const {
if (!jni_cache_ || !jenv_ || !calendar_) return false;
- *value = jenv_->CallLongMethod(calendar_.get(),
- jni_cache_->calendar_get_time_in_millis);
- return !jni_cache_->ExceptionCheckAndClear();
+
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ *value,
+ JniHelper::CallLongMethod(jenv_, calendar_.get(),
+ jni_cache_->calendar_get_time_in_millis));
+
+ return true;
}
CalendarLib::CalendarLib() {
diff --git a/native/utils/calendar/calendar-javaicu.h b/native/utils/calendar/calendar-javaicu.h
index 035530e..d6e1716 100644
--- a/native/utils/calendar/calendar-javaicu.h
+++ b/native/utils/calendar/calendar-javaicu.h
@@ -18,14 +18,15 @@
#define LIBTEXTCLASSIFIER_UTILS_CALENDAR_CALENDAR_JAVAICU_H_
#include <jni.h>
+
#include <memory>
#include <string>
#include "annotator/types.h"
#include "utils/base/integral_types.h"
#include "utils/calendar/calendar-common.h"
+#include "utils/java/jni-base.h"
#include "utils/java/jni-cache.h"
-#include "utils/java/scoped_local_ref.h"
namespace libtextclassifier3 {
@@ -71,12 +72,14 @@
int64 reference_time_ms_utc,
const std::string& reference_timezone,
const std::string& reference_locale,
+ bool prefer_future_for_unspecified_date,
int64* interpreted_time_ms_utc,
DatetimeGranularity* granularity) const {
Calendar calendar(jni_cache_.get());
if (!impl_.InterpretParseData(parse_data, reference_time_ms_utc,
reference_timezone, reference_locale,
- &calendar, granularity)) {
+ prefer_future_for_unspecified_date, &calendar,
+ granularity)) {
return false;
}
return calendar.GetTimeInMillis(interpreted_time_ms_utc);
diff --git a/native/utils/calendar/calendar_test-include.cc b/native/utils/calendar/calendar_test-include.cc
index a145fc2..7fe6f53 100644
--- a/native/utils/calendar/calendar_test-include.cc
+++ b/native/utils/calendar/calendar_test-include.cc
@@ -26,8 +26,10 @@
DatetimeGranularity granularity;
std::string timezone;
DatetimeParsedData data;
- bool result = calendarlib_.InterpretParseData(data, 0L, "Zurich", "en-CH",
- &time, &granularity);
+ bool result = calendarlib_.InterpretParseData(
+ data, /*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Zurich",
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity);
TC3_LOG(INFO) << result;
}
@@ -40,13 +42,15 @@
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/1L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
}
@@ -59,42 +63,48 @@
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1514761200000L /* Jan 01 2018 00:00:00 */);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::MONTH, 4);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1522533600000L /* Apr 01 2018 00:00:00 */);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::DAY_OF_MONTH, 25);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1524607200000L /* Apr 25 2018 00:00:00 */);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 9);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1524639600000L /* Apr 25 2018 09:00:00 */);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 33);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1524641580000 /* Apr 25 2018 09:33:00 */);
data.SetAbsoluteValue(DatetimeComponent::ComponentType::SECOND, 59);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-CH", &time, &granularity));
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1524641639000 /* Apr 25 2018 09:33:59 */);
}
@@ -110,13 +120,15 @@
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"de-CH", &time, &granularity));
+ /*reference_locale=*/"de-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 342000000L /* Mon Jan 05 1970 00:00:00 */);
ASSERT_TRUE(calendarlib_.InterpretParseData(
data,
/*reference_time_ms_utc=*/0L, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
+ /*reference_locale=*/"en-US",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 255600000L /* Sun Jan 04 1970 00:00:00 */);
}
@@ -137,7 +149,8 @@
ASSERT_TRUE(calendarlib_.InterpretParseData(
future_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
+ /*reference_locale=*/"en-US",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1525858439000L /* Wed May 09 2018 11:33:59 */);
EXPECT_EQ(granularity, GRANULARITY_DAY);
@@ -152,7 +165,8 @@
ASSERT_TRUE(calendarlib_.InterpretParseData(
next_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
+ /*reference_locale=*/"en-US",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1525212000000L /* Wed May 02 2018 00:00:00 */);
EXPECT_EQ(granularity, GRANULARITY_DAY);
@@ -167,7 +181,8 @@
ASSERT_TRUE(calendarlib_.InterpretParseData(
same_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
+ /*reference_locale=*/"en-US",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1524607200000L /* Wed Apr 25 2018 00:00:00 */);
EXPECT_EQ(granularity, GRANULARITY_DAY);
@@ -182,7 +197,8 @@
ASSERT_TRUE(calendarlib_.InterpretParseData(
last_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
+ /*reference_locale=*/"en-US",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1524002400000L /* Wed Apr 18 2018 00:00:00 */);
EXPECT_EQ(granularity, GRANULARITY_DAY);
@@ -197,7 +213,8 @@
ASSERT_TRUE(calendarlib_.InterpretParseData(
past_wed_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
+ /*reference_locale=*/"en-US",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1523439239000L /* Wed Apr 11 2018 11:33:59 */);
EXPECT_EQ(granularity, GRANULARITY_DAY);
@@ -210,7 +227,8 @@
ASSERT_TRUE(calendarlib_.InterpretParseData(
in_3_hours_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
+ /*reference_locale=*/"en-US",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1524659639000L /* Wed Apr 25 2018 14:33:59 */);
EXPECT_EQ(granularity, GRANULARITY_HOUR);
@@ -224,7 +242,8 @@
ASSERT_TRUE(calendarlib_.InterpretParseData(
in_5_minutes_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
+ /*reference_locale=*/"en-US",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1524649139000L /* Wed Apr 25 2018 14:33:59 */);
EXPECT_EQ(granularity, GRANULARITY_MINUTE);
@@ -238,10 +257,68 @@
ASSERT_TRUE(calendarlib_.InterpretParseData(
in_10_seconds_parse, ref_time, /*reference_timezone=*/"Europe/Zurich",
- /*reference_locale=*/"en-US", &time, &granularity));
+ /*reference_locale=*/"en-US",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
EXPECT_EQ(time, 1524648849000L /* Wed Apr 25 2018 14:33:59 */);
EXPECT_EQ(granularity, GRANULARITY_SECOND);
}
+TEST_F(CalendarTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
+ int64 time;
+ DatetimeGranularity granularity;
+ DatetimeParsedData data;
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 7);
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);
+
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH", /*prefer_future_for_unspecified_date=*/true,
+ &time, &granularity));
+ EXPECT_EQ(time, 1567401000000L /* Sept 02 2019 07:10:00 */);
+}
+
+TEST_F(CalendarTest,
+ DoesntAddADayWhenTimeInThePastAndDayNotSpecifiedAndDisabled) {
+ int64 time;
+ DatetimeGranularity granularity;
+ DatetimeParsedData data;
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 7);
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);
+
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
+ EXPECT_EQ(time, 1567314600000L /* Sept 01 2019 07:10:00 */);
+}
+
+TEST_F(CalendarTest, DoesntAddADayWhenTimeInTheFutureAndDayNotSpecified) {
+ int64 time;
+ DatetimeGranularity granularity;
+ DatetimeParsedData data;
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::HOUR, 9);
+ data.SetAbsoluteValue(DatetimeComponent::ComponentType::MINUTE, 10);
+
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH", /*prefer_future_for_unspecified_date=*/true,
+ &time, &granularity));
+ EXPECT_EQ(time, 1567321800000L /* Sept 01 2019 09:10:00 */);
+
+ ASSERT_TRUE(calendarlib_.InterpretParseData(
+ data,
+ /*reference_time_ms_utc=*/1567317600000L /* Sept 01 2019 00:00:00 */,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*reference_locale=*/"en-CH",
+ /*prefer_future_for_unspecified_date=*/false, &time, &granularity));
+ EXPECT_EQ(time, 1567321800000L /* Sept 01 2019 09:10:00 */);
+}
+
} // namespace test_internal
} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers.cc b/native/utils/flatbuffers.cc
index 005041d..857d004 100644
--- a/native/utils/flatbuffers.cc
+++ b/native/utils/flatbuffers.cc
@@ -17,6 +17,7 @@
#include "utils/flatbuffers.h"
#include <vector>
+
#include "utils/strings/numbers.h"
#include "utils/variant.h"
@@ -53,6 +54,57 @@
return false;
}
}
+
+// Gets the field information for a field name, returns nullptr if the
+// field was not defined.
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const StringPiece field_name) {
+ TC3_CHECK(type != nullptr && type->fields() != nullptr);
+ return type->fields()->LookupByKey(field_name.data());
+}
+
+const reflection::Field* GetFieldByOffsetOrNull(const reflection::Object* type,
+ const int field_offset) {
+ if (type->fields() == nullptr) {
+ return nullptr;
+ }
+ for (const reflection::Field* field : *type->fields()) {
+ if (field->offset() == field_offset) {
+ return field;
+ }
+ }
+ return nullptr;
+}
+
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const StringPiece field_name,
+ const int field_offset) {
+ // Lookup by name might be faster as the fields are sorted by name in the
+ // schema data, so try that first.
+ if (!field_name.empty()) {
+ return GetFieldOrNull(type, field_name.data());
+ }
+ return GetFieldByOffsetOrNull(type, field_offset);
+}
+
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const FlatbufferField* field) {
+ TC3_CHECK(type != nullptr && field != nullptr);
+ if (field->field_name() == nullptr) {
+ return GetFieldByOffsetOrNull(type, field->field_offset());
+ }
+ return GetFieldOrNull(
+ type,
+ StringPiece(field->field_name()->data(), field->field_name()->size()),
+ field->field_offset());
+}
+
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+ const FlatbufferFieldT* field) {
+ TC3_CHECK(type != nullptr && field != nullptr);
+ return GetFieldOrNull(type, field->field_name, field->field_offset);
+}
+
} // namespace
template <>
@@ -83,17 +135,12 @@
const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
const StringPiece field_name) const {
- return type_->fields()->LookupByKey(field_name.data());
+ return libtextclassifier3::GetFieldOrNull(type_, field_name);
}
const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
const FlatbufferField* field) const {
- // Lookup by name might be faster as the fields are sorted by name in the
- // schema data, so try that first.
- if (field->field_name() != nullptr) {
- return GetFieldOrNull(field->field_name()->str());
- }
- return GetFieldByOffsetOrNull(field->field_offset());
+ return libtextclassifier3::GetFieldOrNull(type_, field);
}
bool ReflectiveFlatbuffer::GetFieldWithParent(
@@ -120,15 +167,7 @@
const reflection::Field* ReflectiveFlatbuffer::GetFieldByOffsetOrNull(
const int field_offset) const {
- if (type_->fields() == nullptr) {
- return nullptr;
- }
- for (const reflection::Field* field : *type_->fields()) {
- if (field->offset() == field_offset) {
- return field;
- }
- }
- return nullptr;
+ return libtextclassifier3::GetFieldByOffsetOrNull(type_, field_offset);
}
bool ReflectiveFlatbuffer::IsMatchingType(const reflection::Field* field,
@@ -411,4 +450,34 @@
}
}
+bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
+ FlatbufferFieldPathT* path) {
+ if (schema == nullptr || !schema->root_table()) {
+ TC3_LOG(ERROR) << "Empty schema provided.";
+ return false;
+ }
+
+ reflection::Object const* type = schema->root_table();
+ for (int i = 0; i < path->field.size(); i++) {
+ const reflection::Field* field = GetFieldOrNull(type, path->field[i].get());
+ if (field == nullptr) {
+ TC3_LOG(ERROR) << "Could not find field: " << path->field[i]->field_name;
+ return false;
+ }
+ path->field[i]->field_name.clear();
+ path->field[i]->field_offset = field->offset();
+
+ // Descend.
+ if (i < path->field.size() - 1) {
+ if (field->type()->base_type() != reflection::Obj) {
+ TC3_LOG(ERROR) << "Field: " << field->name()->str()
+ << " is not of type `Object`.";
+ return false;
+ }
+ type = schema->objects()->Get(field->type()->index());
+ }
+ }
+ return true;
+}
+
} // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers.h b/native/utils/flatbuffers.h
index 17668ff..afa7dc2 100644
--- a/native/utils/flatbuffers.h
+++ b/native/utils/flatbuffers.h
@@ -24,6 +24,7 @@
#include <string>
#include "annotator/model_generated.h"
+#include "utils/flatbuffers_generated.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#include "flatbuffers/flatbuffers.h"
@@ -337,6 +338,10 @@
std::vector<std::unique_ptr<ReflectiveFlatbuffer>> items_;
};
+// Resolves field lookups by name to the concrete field offsets.
+bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
+ FlatbufferFieldPathT* path);
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
diff --git a/native/utils/flatbuffers_test.cc b/native/utils/flatbuffers_test.cc
index 348ca73..72838a8 100644
--- a/native/utils/flatbuffers_test.cc
+++ b/native/utils/flatbuffers_test.cc
@@ -14,12 +14,13 @@
* limitations under the License.
*/
+#include "utils/flatbuffers.h"
+
#include <fstream>
#include <map>
#include <memory>
#include <string>
-#include "utils/flatbuffers.h"
#include "utils/flatbuffers_generated.h"
#include "utils/flatbuffers_test_generated.h"
#include "gmock/gmock.h"
@@ -307,5 +308,23 @@
testing::ElementsAreArray({"note i", "note ii", "note iii"}));
}
+TEST(FlatbuffersTest, ResolvesFieldOffsets) {
+ std::string metadata_buffer = LoadTestMetadata();
+ const reflection::Schema* schema =
+ flatbuffers::GetRoot<reflection::Schema>(metadata_buffer.data());
+ FlatbufferFieldPathT path;
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "flight_number";
+ path.field.emplace_back(new FlatbufferFieldT);
+ path.field.back()->field_name = "carrier_code";
+
+ EXPECT_TRUE(SwapFieldNamesForOffsetsInPath(schema, &path));
+
+ EXPECT_THAT(path.field[0]->field_name, testing::IsEmpty());
+ EXPECT_EQ(14, path.field[0]->field_offset);
+ EXPECT_THAT(path.field[1]->field_name, testing::IsEmpty());
+ EXPECT_EQ(4, path.field[1]->field_offset);
+}
+
} // namespace
} // namespace libtextclassifier3
diff --git a/native/utils/intents/intent-generator.cc b/native/utils/intents/intent-generator.cc
index f882515..7dc24e4 100644
--- a/native/utils/intents/intent-generator.cc
+++ b/native/utils/intents/intent-generator.cc
@@ -22,8 +22,10 @@
#include "actions/types.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
+#include "utils/base/statusor.h"
#include "utils/hash/farmhash.h"
#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
#include "utils/java/string_utils.h"
#include "utils/lua-utils.h"
#include "utils/strings/stringpiece.h"
@@ -99,7 +101,7 @@
bool RetrieveSystemResources();
// Parse the url string by using Uri.parse from Java.
- ScopedLocalRef<jobject> ParseUri(StringPiece url) const;
+ StatusOr<ScopedLocalRef<jobject>> ParseUri(StringPiece url) const;
// Read remote action templates from lua generator.
int ReadRemoteActionTemplates(std::vector<RemoteActionTemplate>* result);
@@ -144,10 +146,12 @@
/*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)) {}
bool JniLuaEnvironment::Initialize() {
- string_ =
- MakeGlobalRef(jenv_->NewStringUTF("string"), jenv_, jni_cache_->jvm);
- android_ =
- MakeGlobalRef(jenv_->NewStringUTF("android"), jenv_, jni_cache_->jvm);
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> string_value,
+ JniHelper::NewStringUTF(jenv_, "string"));
+ string_ = MakeGlobalRef(string_value.get(), jenv_, jni_cache_->jvm);
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jstring> android_value,
+ JniHelper::NewStringUTF(jenv_, "android"));
+ android_ = MakeGlobalRef(android_value.get(), jenv_, jni_cache_->jvm);
if (string_ == nullptr || android_ == nullptr) {
TC3_LOG(ERROR) << "Could not allocate constant strings references.";
return false;
@@ -227,15 +231,23 @@
lua_error(state_);
return 0;
}
- ScopedLocalRef<jstring> package_name_str(
- static_cast<jstring>(jenv_->CallObjectMethod(
- context_, jni_cache_->context_get_package_name)));
- if (jni_cache_->ExceptionCheckAndClear()) {
+
+ StatusOr<ScopedLocalRef<jstring>> status_or_package_name_str =
+ JniHelper::CallObjectMethod<jstring>(
+ jenv_, context_, jni_cache_->context_get_package_name);
+
+ if (!status_or_package_name_str.ok()) {
TC3_LOG(ERROR) << "Error calling Context.getPackageName";
lua_error(state_);
return 0;
}
- PushString(ToStlString(jenv_, package_name_str.get()));
+ StatusOr<std::string> status_or_package_name_std_str =
+ ToStlString(jenv_, status_or_package_name_str.ValueOrDie().get());
+ if (!status_or_package_name_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_package_name_std_str.ValueOrDie());
return 1;
} else if (key.Equals(kUrlEncodeKey)) {
Bind<JniLuaEnvironment, &JniLuaEnvironment::HandleUrlEncode>();
@@ -270,9 +282,11 @@
return 0;
}
- ScopedLocalRef<jobject> bundle(jenv_->CallObjectMethod(
- usermanager_.get(), jni_cache_->usermanager_get_user_restrictions));
- if (jni_cache_->ExceptionCheckAndClear() || bundle == nullptr) {
+ StatusOr<ScopedLocalRef<jobject>> status_or_bundle =
+ JniHelper::CallObjectMethod(
+ jenv_, usermanager_.get(),
+ jni_cache_->usermanager_get_user_restrictions);
+ if (!status_or_bundle.ok() || status_or_bundle.ValueOrDie() == nullptr) {
TC3_LOG(ERROR) << "Error calling getUserRestrictions";
lua_error(state_);
return 0;
@@ -285,19 +299,20 @@
return 0;
}
- ScopedLocalRef<jstring> key = jni_cache_->ConvertToJavaString(key_str);
- if (jni_cache_->ExceptionCheckAndClear() || key == nullptr) {
- TC3_LOG(ERROR) << "Expected string, got null.";
+ const StatusOr<ScopedLocalRef<jstring>> status_or_key =
+ jni_cache_->ConvertToJavaString(key_str);
+ if (!status_or_key.ok()) {
lua_error(state_);
return 0;
}
- const bool permission = jenv_->CallBooleanMethod(
- bundle.get(), jni_cache_->bundle_get_boolean, key.get());
- if (jni_cache_->ExceptionCheckAndClear()) {
+ const StatusOr<bool> status_or_permission = JniHelper::CallBooleanMethod(
+ jenv_, status_or_bundle.ValueOrDie().get(),
+ jni_cache_->bundle_get_boolean, status_or_key.ValueOrDie().get());
+ if (!status_or_permission.ok()) {
TC3_LOG(ERROR) << "Error getting bundle value";
lua_pushboolean(state_, false);
} else {
- lua_pushboolean(state_, permission);
+ lua_pushboolean(state_, status_or_permission.ValueOrDie());
}
return 1;
}
@@ -311,42 +326,53 @@
}
// Call Java URL encoder.
- ScopedLocalRef<jstring> input_str = jni_cache_->ConvertToJavaString(input);
- if (jni_cache_->ExceptionCheckAndClear() || input_str == nullptr) {
- TC3_LOG(ERROR) << "Expected string, got null.";
+ const StatusOr<ScopedLocalRef<jstring>> status_or_input_str =
+ jni_cache_->ConvertToJavaString(input);
+ if (!status_or_input_str.ok()) {
lua_error(state_);
return 0;
}
- ScopedLocalRef<jstring> encoded_str(
- static_cast<jstring>(jenv_->CallStaticObjectMethod(
- jni_cache_->urlencoder_class.get(), jni_cache_->urlencoder_encode,
- input_str.get(), jni_cache_->string_utf8.get())));
- if (jni_cache_->ExceptionCheckAndClear()) {
+ StatusOr<ScopedLocalRef<jstring>> status_or_encoded_str =
+ JniHelper::CallStaticObjectMethod<jstring>(
+ jenv_, jni_cache_->urlencoder_class.get(),
+ jni_cache_->urlencoder_encode, status_or_input_str.ValueOrDie().get(),
+ jni_cache_->string_utf8.get());
+
+ if (!status_or_encoded_str.ok()) {
TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
lua_error(state_);
return 0;
}
- PushString(ToStlString(jenv_, encoded_str.get()));
+ const StatusOr<std::string> status_or_encoded_std_str =
+ ToStlString(jenv_, status_or_encoded_str.ValueOrDie().get());
+ if (!status_or_encoded_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_encoded_std_str.ValueOrDie());
return 1;
}
-ScopedLocalRef<jobject> JniLuaEnvironment::ParseUri(StringPiece url) const {
+StatusOr<ScopedLocalRef<jobject>> JniLuaEnvironment::ParseUri(
+ StringPiece url) const {
if (url.empty()) {
- return nullptr;
+ return {Status::UNKNOWN};
}
// Call to Java URI parser.
- ScopedLocalRef<jstring> url_str = jni_cache_->ConvertToJavaString(url);
- if (jni_cache_->ExceptionCheckAndClear() || url_str == nullptr) {
- TC3_LOG(ERROR) << "Expected string, got null";
- return nullptr;
- }
+ TC3_ASSIGN_OR_RETURN(
+ const StatusOr<ScopedLocalRef<jstring>> status_or_url_str,
+ jni_cache_->ConvertToJavaString(url));
// Try to parse uri and get scheme.
- ScopedLocalRef<jobject> uri(jenv_->CallStaticObjectMethod(
- jni_cache_->uri_class.get(), jni_cache_->uri_parse, url_str.get()));
- if (jni_cache_->ExceptionCheckAndClear() || uri == nullptr) {
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> uri,
+ JniHelper::CallStaticObjectMethod(jenv_, jni_cache_->uri_class.get(),
+ jni_cache_->uri_parse,
+ status_or_url_str.ValueOrDie().get()));
+ if (uri == nullptr) {
TC3_LOG(ERROR) << "Error calling Uri.parse";
+ return {Status::UNKNOWN};
}
return uri;
}
@@ -354,47 +380,64 @@
int JniLuaEnvironment::HandleUrlSchema() {
StringPiece url = ReadString(/*index=*/1);
- ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
- if (parsed_uri == nullptr) {
+ const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
+ if (!status_or_parsed_uri.ok()) {
lua_error(state_);
return 0;
}
- ScopedLocalRef<jstring> scheme_str(static_cast<jstring>(
- jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_scheme)));
- if (jni_cache_->ExceptionCheckAndClear()) {
+ const StatusOr<ScopedLocalRef<jstring>> status_or_scheme_str =
+ JniHelper::CallObjectMethod<jstring>(
+ jenv_, status_or_parsed_uri.ValueOrDie().get(),
+ jni_cache_->uri_get_scheme);
+ if (!status_or_scheme_str.ok()) {
TC3_LOG(ERROR) << "Error calling Uri.getScheme";
lua_error(state_);
return 0;
}
- if (scheme_str == nullptr) {
+ if (status_or_scheme_str.ValueOrDie() == nullptr) {
lua_pushnil(state_);
} else {
- PushString(ToStlString(jenv_, scheme_str.get()));
+ const StatusOr<std::string> status_or_scheme_std_str =
+ ToStlString(jenv_, status_or_scheme_str.ValueOrDie().get());
+ if (!status_or_scheme_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_scheme_std_str.ValueOrDie());
}
return 1;
}
int JniLuaEnvironment::HandleUrlHost() {
- StringPiece url = ReadString(/*index=*/-1);
+ const StringPiece url = ReadString(/*index=*/-1);
- ScopedLocalRef<jobject> parsed_uri = ParseUri(url);
- if (parsed_uri == nullptr) {
+ const StatusOr<ScopedLocalRef<jobject>> status_or_parsed_uri = ParseUri(url);
+ if (!status_or_parsed_uri.ok()) {
lua_error(state_);
return 0;
}
- ScopedLocalRef<jstring> host_str(static_cast<jstring>(
- jenv_->CallObjectMethod(parsed_uri.get(), jni_cache_->uri_get_host)));
- if (jni_cache_->ExceptionCheckAndClear()) {
+ const StatusOr<ScopedLocalRef<jstring>> status_or_host_str =
+ JniHelper::CallObjectMethod<jstring>(
+ jenv_, status_or_parsed_uri.ValueOrDie().get(),
+ jni_cache_->uri_get_host);
+ if (!status_or_host_str.ok()) {
TC3_LOG(ERROR) << "Error calling Uri.getHost";
lua_error(state_);
return 0;
}
- if (host_str == nullptr) {
+
+ if (status_or_host_str.ValueOrDie() == nullptr) {
lua_pushnil(state_);
} else {
- PushString(ToStlString(jenv_, host_str.get()));
+ const StatusOr<std::string> status_or_host_std_str =
+ ToStlString(jenv_, status_or_host_str.ValueOrDie().get());
+ if (!status_or_host_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_host_std_str.ValueOrDie());
}
return 1;
}
@@ -458,21 +501,23 @@
lua_error(state_);
return 0;
}
- ScopedLocalRef<jstring> resource_name =
+ const StatusOr<ScopedLocalRef<jstring>> status_or_resource_name =
jni_cache_->ConvertToJavaString(resource_name_str);
- if (resource_name == nullptr) {
+ if (!status_or_resource_name.ok()) {
TC3_LOG(ERROR) << "Invalid resource name.";
lua_error(state_);
return 0;
}
- resource_id = jenv_->CallIntMethod(
- system_resources_.get(), jni_cache_->resources_get_identifier,
- resource_name.get(), string_.get(), android_.get());
- if (jni_cache_->ExceptionCheckAndClear()) {
+ StatusOr<int> status_or_resource_id = JniHelper::CallIntMethod(
+ jenv_, system_resources_.get(), jni_cache_->resources_get_identifier,
+ status_or_resource_name.ValueOrDie().get(), string_.get(),
+ android_.get());
+ if (!status_or_resource_id.ok()) {
TC3_LOG(ERROR) << "Error calling getIdentifier.";
lua_error(state_);
return 0;
}
+ resource_id = status_or_resource_id.ValueOrDie();
break;
}
default:
@@ -485,18 +530,26 @@
lua_pushnil(state_);
return 1;
}
- ScopedLocalRef<jstring> resource_str(static_cast<jstring>(
- jenv_->CallObjectMethod(system_resources_.get(),
- jni_cache_->resources_get_string, resource_id)));
- if (jni_cache_->ExceptionCheckAndClear()) {
+ StatusOr<ScopedLocalRef<jstring>> status_or_resource_str =
+ JniHelper::CallObjectMethod<jstring>(jenv_, system_resources_.get(),
+ jni_cache_->resources_get_string,
+ resource_id);
+ if (!status_or_resource_str.ok()) {
TC3_LOG(ERROR) << "Error calling getString.";
lua_error(state_);
return 0;
}
- if (resource_str == nullptr) {
+
+ if (status_or_resource_str.ValueOrDie() == nullptr) {
lua_pushnil(state_);
} else {
- PushString(ToStlString(jenv_, resource_str.get()));
+ StatusOr<std::string> status_or_resource_std_str =
+ ToStlString(jenv_, status_or_resource_str.ValueOrDie().get());
+ if (!status_or_resource_std_str.ok()) {
+ lua_error(state_);
+ return 0;
+ }
+ PushString(status_or_resource_std_str.ValueOrDie());
}
return 1;
}
@@ -506,14 +559,12 @@
return (system_resources_ != nullptr);
}
system_resources_resources_retrieved_ = true;
- jobject system_resources_ref = jenv_->CallStaticObjectMethod(
- jni_cache_->resources_class.get(), jni_cache_->resources_get_system);
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "Error calling getSystem.";
- return false;
- }
+ TC3_ASSIGN_OR_RETURN_FALSE(ScopedLocalRef<jobject> system_resources_ref,
+ JniHelper::CallStaticObjectMethod(
+ jenv_, jni_cache_->resources_class.get(),
+ jni_cache_->resources_get_system));
system_resources_ =
- MakeGlobalRef(system_resources_ref, jenv_, jni_cache_->jvm);
+ MakeGlobalRef(system_resources_ref.get(), jenv_, jni_cache_->jvm);
return (system_resources_ != nullptr);
}
@@ -525,14 +576,15 @@
return (usermanager_ != nullptr);
}
usermanager_retrieved_ = true;
- ScopedLocalRef<jstring> service(jenv_->NewStringUTF("user"));
- jobject usermanager_ref = jenv_->CallObjectMethod(
- context_, jni_cache_->context_get_system_service, service.get());
- if (jni_cache_->ExceptionCheckAndClear()) {
- TC3_LOG(ERROR) << "Error calling getSystemService.";
- return false;
- }
- usermanager_ = MakeGlobalRef(usermanager_ref, jenv_, jni_cache_->jvm);
+ TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> service,
+ JniHelper::NewStringUTF(jenv_, "user"));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ const ScopedLocalRef<jobject> usermanager_ref,
+ JniHelper::CallObjectMethod(jenv_, context_,
+ jni_cache_->context_get_system_service,
+ service.get()));
+
+ usermanager_ = MakeGlobalRef(usermanager_ref.get(), jenv_, jni_cache_->jvm);
return (usermanager_ != nullptr);
}
diff --git a/native/utils/intents/intent-generator.h b/native/utils/intents/intent-generator.h
index 9177adb..0297088 100644
--- a/native/utils/intents/intent-generator.h
+++ b/native/utils/intents/intent-generator.h
@@ -3,6 +3,7 @@
#define LIBTEXTCLASSIFIER_UTILS_INTENTS_INTENT_GENERATOR_H_
#include <jni.h>
+
#include <map>
#include <memory>
#include <string>
@@ -13,7 +14,6 @@
#include "utils/i18n/locale.h"
#include "utils/intents/intent-config_generated.h"
#include "utils/java/jni-cache.h"
-#include "utils/java/scoped_local_ref.h"
#include "utils/optional.h"
#include "utils/resources.h"
#include "utils/resources_generated.h"
diff --git a/native/utils/intents/jni.cc b/native/utils/intents/jni.cc
index d6274b1..fa78be4 100644
--- a/native/utils/intents/jni.cc
+++ b/native/utils/intents/jni.cc
@@ -15,18 +15,27 @@
*/
#include "utils/intents/jni.h"
+
#include <memory>
+
+#include "utils/base/statusor.h"
#include "utils/intents/intent-generator.h"
-#include "utils/java/scoped_local_ref.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
namespace libtextclassifier3 {
// The macros below are intended to reduce the boilerplate and avoid
// easily introduced copy/paste errors.
#define TC3_CHECK_JNI_PTR(PTR) TC3_CHECK((PTR) != nullptr)
-#define TC3_GET_CLASS(FIELD, NAME) \
- handler->FIELD = MakeGlobalRef(env->FindClass(NAME), env, jni_cache->jvm); \
- TC3_CHECK_JNI_PTR(handler->FIELD) << "Error finding class: " << NAME;
+#define TC3_GET_CLASS(FIELD, NAME) \
+ { \
+ StatusOr<ScopedLocalRef<jclass>> status_or_clazz = \
+ JniHelper::FindClass(env, NAME); \
+ handler->FIELD = MakeGlobalRef(status_or_clazz.ValueOrDie().release(), \
+ env, jni_cache->jvm); \
+ TC3_CHECK_JNI_PTR(handler->FIELD) << "Error finding class: " << NAME; \
+ }
#define TC3_GET_METHOD(CLASS, FIELD, NAME, SIGNATURE) \
handler->FIELD = env->GetMethodID(handler->CLASS.get(), NAME, SIGNATURE); \
TC3_CHECK(handler->FIELD) << "Error finding method: " << NAME;
@@ -73,148 +82,183 @@
return handler;
}
-jstring RemoteActionTemplatesHandler::AsUTF8String(
+StatusOr<ScopedLocalRef<jstring>> RemoteActionTemplatesHandler::AsUTF8String(
const Optional<std::string>& optional) const {
if (!optional.has_value()) {
- return nullptr;
+ return {{nullptr, jni_cache_->GetEnv()}};
}
- return jni_cache_->ConvertToJavaString(optional.value()).release();
+ return jni_cache_->ConvertToJavaString(optional.value());
}
-jobject RemoteActionTemplatesHandler::AsInteger(
+StatusOr<ScopedLocalRef<jobject>> RemoteActionTemplatesHandler::AsInteger(
const Optional<int>& optional) const {
- return (optional.has_value()
- ? jni_cache_->GetEnv()->NewObject(integer_class_.get(),
- integer_init_, optional.value())
- : nullptr);
+ if (!optional.has_value()) {
+ return {{nullptr, jni_cache_->GetEnv()}};
+ }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(jni_cache_->GetEnv(), integer_class_.get(),
+ integer_init_, optional.value()));
+
+ return result;
}
-jobjectArray RemoteActionTemplatesHandler::AsStringArray(
+StatusOr<ScopedLocalRef<jobjectArray>>
+RemoteActionTemplatesHandler::AsStringArray(
const std::vector<std::string>& values) const {
if (values.empty()) {
- return nullptr;
+ return {{nullptr, jni_cache_->GetEnv()}};
}
- jobjectArray result = jni_cache_->GetEnv()->NewObjectArray(
- values.size(), jni_cache_->string_class.get(), nullptr);
- if (result == nullptr) {
- return nullptr;
- }
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> result,
+ JniHelper::NewObjectArray(jni_cache_->GetEnv(), values.size(),
+ jni_cache_->string_class.get(), nullptr));
+
for (int k = 0; k < values.size(); k++) {
- ScopedLocalRef<jstring> value_str =
- jni_cache_->ConvertToJavaString(values[k]);
- jni_cache_->GetEnv()->SetObjectArrayElement(result, k, value_str.get());
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> value_str,
+ jni_cache_->ConvertToJavaString(values[k]));
+ jni_cache_->GetEnv()->SetObjectArrayElement(result.get(), k,
+ value_str.get());
}
return result;
}
-jobject RemoteActionTemplatesHandler::AsNamedVariant(
+StatusOr<ScopedLocalRef<jobject>> RemoteActionTemplatesHandler::AsNamedVariant(
const std::string& name_str, const Variant& value) const {
- ScopedLocalRef<jstring> name = jni_cache_->ConvertToJavaString(name_str);
- if (name == nullptr) {
- return nullptr;
- }
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> name,
+ jni_cache_->ConvertToJavaString(name_str));
+
+ JNIEnv* env = jni_cache_->GetEnv();
switch (value.GetType()) {
case Variant::TYPE_INT_VALUE:
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_int_,
- name.get(), value.IntValue());
+ return JniHelper::NewObject(env, named_variant_class_.get(),
+ named_variant_from_int_, name.get(),
+ value.IntValue());
+
case Variant::TYPE_INT64_VALUE:
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_long_,
- name.get(), value.Int64Value());
+ return JniHelper::NewObject(env, named_variant_class_.get(),
+ named_variant_from_long_, name.get(),
+ value.Int64Value());
+
case Variant::TYPE_FLOAT_VALUE:
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_float_,
- name.get(), value.FloatValue());
+ return JniHelper::NewObject(env, named_variant_class_.get(),
+ named_variant_from_float_, name.get(),
+ value.FloatValue());
+
case Variant::TYPE_DOUBLE_VALUE:
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_double_,
- name.get(), value.DoubleValue());
+ return JniHelper::NewObject(env, named_variant_class_.get(),
+ named_variant_from_double_, name.get(),
+ value.DoubleValue());
+
case Variant::TYPE_BOOL_VALUE:
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_bool_,
- name.get(), value.BoolValue());
+ return JniHelper::NewObject(env, named_variant_class_.get(),
+ named_variant_from_bool_, name.get(),
+ value.BoolValue());
+
case Variant::TYPE_STRING_VALUE: {
- ScopedLocalRef<jstring> value_jstring =
- jni_cache_->ConvertToJavaString(value.StringValue());
- if (value_jstring == nullptr) {
- return nullptr;
- }
- return jni_cache_->GetEnv()->NewObject(named_variant_class_.get(),
- named_variant_from_string_,
- name.get(), value_jstring.get());
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jstring> value_jstring,
+ jni_cache_->ConvertToJavaString(value.StringValue()));
+ return JniHelper::NewObject(env, named_variant_class_.get(),
+ named_variant_from_string_, name.get(),
+ value_jstring.get());
}
- default:
- return nullptr;
+
+ case Variant::TYPE_EMPTY:
+ return {Status::UNKNOWN};
}
}
-jobjectArray RemoteActionTemplatesHandler::AsNamedVariantArray(
+StatusOr<ScopedLocalRef<jobjectArray>>
+RemoteActionTemplatesHandler::AsNamedVariantArray(
const std::map<std::string, Variant>& values) const {
+ JNIEnv* env = jni_cache_->GetEnv();
if (values.empty()) {
- return nullptr;
+ return {{nullptr, env}};
}
- jobjectArray result = jni_cache_->GetEnv()->NewObjectArray(
- values.size(), named_variant_class_.get(), nullptr);
+
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> result,
+ JniHelper::NewObjectArray(jni_cache_->GetEnv(), values.size(),
+ named_variant_class_.get(), nullptr));
int element_index = 0;
for (auto key_value_pair : values) {
if (!key_value_pair.second.HasValue()) {
element_index++;
continue;
}
- ScopedLocalRef<jobject> named_extra(
- AsNamedVariant(key_value_pair.first, key_value_pair.second),
- jni_cache_->GetEnv());
- if (named_extra == nullptr) {
- return nullptr;
- }
- jni_cache_->GetEnv()->SetObjectArrayElement(result, element_index,
- named_extra.get());
+ TC3_ASSIGN_OR_RETURN(
+ StatusOr<ScopedLocalRef<jobject>> named_extra,
+ AsNamedVariant(key_value_pair.first, key_value_pair.second));
+ env->SetObjectArrayElement(result.get(), element_index,
+ named_extra.ValueOrDie().get());
element_index++;
}
return result;
}
-jobjectArray RemoteActionTemplatesHandler::RemoteActionTemplatesToJObjectArray(
+StatusOr<ScopedLocalRef<jobjectArray>>
+RemoteActionTemplatesHandler::RemoteActionTemplatesToJObjectArray(
const std::vector<RemoteActionTemplate>& remote_actions) const {
- const jobjectArray results = jni_cache_->GetEnv()->NewObjectArray(
- remote_actions.size(), remote_action_template_class_.get(), nullptr);
- if (results == nullptr) {
- return nullptr;
- }
+ JNIEnv* env = jni_cache_->GetEnv();
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> results,
+ JniHelper::NewObjectArray(env, remote_actions.size(),
+ remote_action_template_class_.get(), nullptr));
+
for (int i = 0; i < remote_actions.size(); i++) {
const RemoteActionTemplate& remote_action = remote_actions[i];
- const jstring title_without_entity =
- AsUTF8String(remote_action.title_without_entity);
- const jstring title_with_entity =
- AsUTF8String(remote_action.title_with_entity);
- const jstring description = AsUTF8String(remote_action.description);
- const jstring description_with_app_name =
- AsUTF8String(remote_action.description_with_app_name);
- const jstring action = AsUTF8String(remote_action.action);
- const jstring data = AsUTF8String(remote_action.data);
- const jstring type = AsUTF8String(remote_action.type);
- const jobject flags = AsInteger(remote_action.flags);
- const jobjectArray category = AsStringArray(remote_action.category);
- const jstring package = AsUTF8String(remote_action.package_name);
- const jobjectArray extra = AsNamedVariantArray(remote_action.extra);
- const jobject request_code = AsInteger(remote_action.request_code);
- ScopedLocalRef<jobject> result(
- jni_cache_->GetEnv()->NewObject(
- remote_action_template_class_.get(), remote_action_template_init_,
- title_without_entity, title_with_entity, description,
- description_with_app_name, action, data, type, flags, category,
- package, extra, request_code),
- jni_cache_->GetEnv());
- if (result == nullptr) {
- return nullptr;
- }
- jni_cache_->GetEnv()->SetObjectArrayElement(results, i, result.get());
+
+ TC3_ASSIGN_OR_RETURN(
+ const StatusOr<ScopedLocalRef<jstring>> title_without_entity,
+ AsUTF8String(remote_action.title_without_entity));
+ TC3_ASSIGN_OR_RETURN(
+ const StatusOr<ScopedLocalRef<jstring>> title_with_entity,
+ AsUTF8String(remote_action.title_with_entity));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jstring>> description,
+ AsUTF8String(remote_action.description));
+ TC3_ASSIGN_OR_RETURN(
+ const StatusOr<ScopedLocalRef<jstring>> description_with_app_name,
+ AsUTF8String(remote_action.description_with_app_name));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jstring>> action,
+ AsUTF8String(remote_action.action));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jstring>> data,
+ AsUTF8String(remote_action.data));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jstring>> type,
+ AsUTF8String(remote_action.type));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jobject>> flags,
+ AsInteger(remote_action.flags));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jobjectArray>> category,
+ AsStringArray(remote_action.category));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jstring>> package,
+ AsUTF8String(remote_action.package_name));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jobjectArray>> extra,
+ AsNamedVariantArray(remote_action.extra));
+ TC3_ASSIGN_OR_RETURN(const StatusOr<ScopedLocalRef<jobject>> request_code,
+ AsInteger(remote_action.request_code));
+
+ TC3_ASSIGN_OR_RETURN(
+ const ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(
+ env, remote_action_template_class_.get(),
+ remote_action_template_init_,
+ title_without_entity.ValueOrDie().get(),
+ title_with_entity.ValueOrDie().get(),
+ description.ValueOrDie().get(),
+ description_with_app_name.ValueOrDie().get(),
+ action.ValueOrDie().get(), data.ValueOrDie().get(),
+ type.ValueOrDie().get(), flags.ValueOrDie().get(),
+ category.ValueOrDie().get(), package.ValueOrDie().get(),
+ extra.ValueOrDie().get(), request_code.ValueOrDie().get()));
+ env->SetObjectArrayElement(results.get(), i, result.get());
}
return results;
}
-jobject RemoteActionTemplatesHandler::EntityDataAsNamedVariantArray(
+StatusOr<ScopedLocalRef<jobjectArray>>
+RemoteActionTemplatesHandler::EntityDataAsNamedVariantArray(
const reflection::Schema* entity_data_schema,
const std::string& serialized_entity_data) const {
ReflectiveFlatbufferBuilder entity_data_builder(entity_data_schema);
diff --git a/native/utils/intents/jni.h b/native/utils/intents/jni.h
index 37952a2..0032b8a 100644
--- a/native/utils/intents/jni.h
+++ b/native/utils/intents/jni.h
@@ -18,11 +18,13 @@
#define LIBTEXTCLASSIFIER_UTILS_INTENTS_JNI_H_
#include <jni.h>
+
#include <map>
#include <memory>
#include <string>
#include <vector>
+#include "utils/base/statusor.h"
#include "utils/flatbuffers.h"
#include "utils/intents/intent-generator.h"
#include "utils/java/jni-base.h"
@@ -52,17 +54,21 @@
static std::unique_ptr<RemoteActionTemplatesHandler> Create(
const std::shared_ptr<JniCache>& jni_cache);
- jstring AsUTF8String(const Optional<std::string>& optional) const;
- jobject AsInteger(const Optional<int>& optional) const;
- jobjectArray AsStringArray(const std::vector<std::string>& values) const;
- jobject AsNamedVariant(const std::string& name, const Variant& value) const;
- jobjectArray AsNamedVariantArray(
+ StatusOr<ScopedLocalRef<jstring>> AsUTF8String(
+ const Optional<std::string>& optional) const;
+ StatusOr<ScopedLocalRef<jobject>> AsInteger(
+ const Optional<int>& optional) const;
+ StatusOr<ScopedLocalRef<jobjectArray>> AsStringArray(
+ const std::vector<std::string>& values) const;
+ StatusOr<ScopedLocalRef<jobject>> AsNamedVariant(const std::string& name,
+ const Variant& value) const;
+ StatusOr<ScopedLocalRef<jobjectArray>> AsNamedVariantArray(
const std::map<std::string, Variant>& values) const;
- jobjectArray RemoteActionTemplatesToJObjectArray(
+ StatusOr<ScopedLocalRef<jobjectArray>> RemoteActionTemplatesToJObjectArray(
const std::vector<RemoteActionTemplate>& remote_actions) const;
- jobject EntityDataAsNamedVariantArray(
+ StatusOr<ScopedLocalRef<jobjectArray>> EntityDataAsNamedVariantArray(
const reflection::Schema* entity_data_schema,
const std::string& serialized_entity_data) const;
diff --git a/native/utils/intents/zlib-utils.cc b/native/utils/intents/zlib-utils.cc
index 9f29b46..78489cc 100644
--- a/native/utils/intents/zlib-utils.cc
+++ b/native/utils/intents/zlib-utils.cc
@@ -18,6 +18,7 @@
#include <memory>
+#include "utils/base/logging.h"
#include "utils/zlib/buffer_generated.h"
#include "utils/zlib/zlib.h"
@@ -38,4 +39,33 @@
return true;
}
+bool DecompressIntentModel(IntentFactoryModelT* intent_model) {
+ std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+ ZlibDecompressor::Instance();
+ if (!zlib_decompressor) {
+ TC3_LOG(ERROR) << "Cannot initialize decompressor.";
+ return false;
+ }
+
+ for (std::unique_ptr<IntentFactoryModel_::IntentGeneratorT>& generator :
+ intent_model->generator) {
+ if (generator->compressed_lua_template_generator == nullptr) {
+ continue;
+ }
+
+ std::string lua_template_generator;
+ if (!zlib_decompressor->MaybeDecompress(
+ generator->compressed_lua_template_generator.get(),
+ &lua_template_generator)) {
+ TC3_LOG(ERROR) << "Cannot decompress intent template.";
+ return false;
+ }
+ generator->lua_template_generator = std::vector<uint8_t>(
+ lua_template_generator.begin(), lua_template_generator.end());
+
+ generator->compressed_lua_template_generator.reset(nullptr);
+ }
+ return true;
+}
+
} // namespace libtextclassifier3
diff --git a/native/utils/intents/zlib-utils.h b/native/utils/intents/zlib-utils.h
index afefa3d..b9a370f 100644
--- a/native/utils/intents/zlib-utils.h
+++ b/native/utils/intents/zlib-utils.h
@@ -22,6 +22,7 @@
namespace libtextclassifier3 {
bool CompressIntentModel(IntentFactoryModelT* intent_model);
+bool DecompressIntentModel(IntentFactoryModelT* intent_model);
} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-base.cc b/native/utils/java/jni-base.cc
index e04fcf3..e0829b7 100644
--- a/native/utils/java/jni-base.cc
+++ b/native/utils/java/jni-base.cc
@@ -16,13 +16,30 @@
#include "utils/java/jni-base.h"
+#include "utils/base/status.h"
#include "utils/java/string_utils.h"
namespace libtextclassifier3 {
-std::string ToStlString(JNIEnv* env, const jstring& str) {
+bool EnsureLocalCapacity(JNIEnv* env, int capacity) {
+ return env->EnsureLocalCapacity(capacity) == JNI_OK;
+}
+
+bool JniExceptionCheckAndClear(JNIEnv* env) {
+ TC3_CHECK(env != nullptr);
+ const bool result = env->ExceptionCheck();
+ if (result) {
+ env->ExceptionDescribe();
+ env->ExceptionClear();
+ }
+ return result;
+}
+
+StatusOr<std::string> ToStlString(JNIEnv* env, const jstring& str) {
std::string result;
- JStringToUtf8String(env, str, &result);
+ if (!JStringToUtf8String(env, str, &result)) {
+ return {Status::UNKNOWN};
+ }
return result;
}
diff --git a/native/utils/java/jni-base.h b/native/utils/java/jni-base.h
index 05bc082..c7b04e6 100644
--- a/native/utils/java/jni-base.h
+++ b/native/utils/java/jni-base.h
@@ -18,8 +18,11 @@
#define LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_BASE_H_
#include <jni.h>
+
#include <string>
+#include "utils/base/statusor.h"
+
// When we use a macro as an argument for a macro, an additional level of
// indirection is needed, if the macro argument is used with # or ##.
#define TC3_ADD_QUOTES_HELPER(TOKEN) #TOKEN
@@ -58,20 +61,158 @@
namespace libtextclassifier3 {
-template <typename T, typename F>
-std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object,
- jclass class_object, F function,
- const std::string& method_name,
- const std::string& return_java_type) {
- const jmethodID method = env->GetMethodID(class_object, method_name.c_str(),
- ("()" + return_java_type).c_str());
- if (!method) {
- return std::make_pair(false, T());
+// Returns true if the requested capacity is available.
+bool EnsureLocalCapacity(JNIEnv* env, int capacity);
+
+// Returns true if there was an exception. Also it clears the exception.
+bool JniExceptionCheckAndClear(JNIEnv* env);
+
+StatusOr<std::string> ToStlString(JNIEnv* env, const jstring& str);
+
+// A deleter to be used with std::unique_ptr to delete JNI global references.
+class GlobalRefDeleter {
+ public:
+ explicit GlobalRefDeleter(JavaVM* jvm) : jvm_(jvm) {}
+
+ GlobalRefDeleter(const GlobalRefDeleter& orig) = default;
+
+ // Copy assignment to allow move semantics in ScopedGlobalRef.
+ GlobalRefDeleter& operator=(const GlobalRefDeleter& rhs) {
+ TC3_CHECK_EQ(jvm_, rhs.jvm_);
+ return *this;
}
- return std::make_pair(true, (env->*function)(object, method));
+
+ // The delete operator.
+ void operator()(jobject object) const {
+ JNIEnv* env;
+ if (object != nullptr && jvm_ != nullptr &&
+ JNI_OK ==
+ jvm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_4)) {
+ env->DeleteGlobalRef(object);
+ }
+ }
+
+ private:
+ // The jvm_ stashed to use for deletion.
+ JavaVM* const jvm_;
+};
+
+// A deleter to be used with std::unique_ptr to delete JNI local references.
+class LocalRefDeleter {
+ public:
+ explicit LocalRefDeleter(JNIEnv* env)
+ : env_(env) {} // NOLINT(runtime/explicit)
+
+ LocalRefDeleter(const LocalRefDeleter& orig) = default;
+
+ // Copy assignment to allow move semantics in ScopedLocalRef.
+ LocalRefDeleter& operator=(const LocalRefDeleter& rhs) {
+ env_ = rhs.env_;
+ return *this;
+ }
+
+ // The delete operator.
+ void operator()(jobject object) const {
+ if (env_) {
+ env_->DeleteLocalRef(object);
+ }
+ }
+
+ private:
+ // The env_ stashed to use for deletion. Thread-local, don't share!
+ JNIEnv* env_;
+};
+
+// A smart pointer that deletes a reference when it goes out of scope.
+//
+// Note that this class is not thread-safe since it caches JNIEnv in
+// the deleter. Do not use the same jobject across different threads.
+template <typename T, typename Env, typename Deleter>
+class ScopedRef {
+ public:
+ ScopedRef() : ptr_(nullptr, Deleter(nullptr)) {}
+ ScopedRef(T value, Env* env) : ptr_(value, Deleter(env)) {}
+
+ T get() const { return ptr_.get(); }
+
+ T release() { return ptr_.release(); }
+
+ bool operator!() const { return !ptr_; }
+
+ bool operator==(void* value) const { return ptr_.get() == value; }
+
+ explicit operator bool() const { return ptr_ != nullptr; }
+
+ void reset(T value, Env* env) {
+ ptr_.reset(value);
+ ptr_.get_deleter() = Deleter(env);
+ }
+
+ private:
+ std::unique_ptr<typename std::remove_pointer<T>::type, Deleter> ptr_;
+};
+
+template <typename T, typename U, typename Env, typename Deleter>
+inline bool operator==(const ScopedRef<T, Env, Deleter>& x,
+ const ScopedRef<U, Env, Deleter>& y) {
+ return x.get() == y.get();
}
-std::string ToStlString(JNIEnv* env, const jstring& str);
+template <typename T, typename Env, typename Deleter>
+inline bool operator==(const ScopedRef<T, Env, Deleter>& x, std::nullptr_t) {
+ return x.get() == nullptr;
+}
+
+template <typename T, typename Env, typename Deleter>
+inline bool operator==(std::nullptr_t, const ScopedRef<T, Env, Deleter>& x) {
+ return nullptr == x.get();
+}
+
+template <typename T, typename U, typename Env, typename Deleter>
+inline bool operator!=(const ScopedRef<T, Env, Deleter>& x,
+ const ScopedRef<U, Env, Deleter>& y) {
+ return x.get() != y.get();
+}
+
+template <typename T, typename Env, typename Deleter>
+inline bool operator!=(const ScopedRef<T, Env, Deleter>& x, std::nullptr_t) {
+ return x.get() != nullptr;
+}
+
+template <typename T, typename Env, typename Deleter>
+inline bool operator!=(std::nullptr_t, const ScopedRef<T, Env, Deleter>& x) {
+ return nullptr != x.get();
+}
+
+template <typename T, typename U, typename Env, typename Deleter>
+inline bool operator<(const ScopedRef<T, Env, Deleter>& x,
+ const ScopedRef<U, Env, Deleter>& y) {
+ return x.get() < y.get();
+}
+
+template <typename T, typename U, typename Env, typename Deleter>
+inline bool operator>(const ScopedRef<T, Env, Deleter>& x,
+ const ScopedRef<U, Env, Deleter>& y) {
+ return x.get() > y.get();
+}
+
+// A smart pointer that deletes a JNI global reference when it goes out
+// of scope. Usage is:
+// ScopedGlobalRef<jobject> scoped_global(env->JniFunction(), jvm);
+template <typename T>
+using ScopedGlobalRef = ScopedRef<T, JavaVM, GlobalRefDeleter>;
+
+// Ditto, but usage is:
+// ScopedLocalRef<jobject> scoped_local(env->JniFunction(), env);
+template <typename T>
+using ScopedLocalRef = ScopedRef<T, JNIEnv, LocalRefDeleter>;
+
+// A helper to create global references.
+template <typename T>
+ScopedGlobalRef<T> MakeGlobalRef(T object, JNIEnv* env, JavaVM* jvm) {
+ const jobject global_object = env->NewGlobalRef(object);
+ return ScopedGlobalRef<T>(reinterpret_cast<T>(global_object), jvm);
+}
} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-cache.cc b/native/utils/java/jni-cache.cc
index 8c2f00a..32b4708 100644
--- a/native/utils/java/jni-cache.cc
+++ b/native/utils/java/jni-cache.cc
@@ -17,6 +17,8 @@
#include "utils/java/jni-cache.h"
#include "utils/base/logging.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
namespace libtextclassifier3 {
@@ -260,43 +262,34 @@
}
bool JniCache::ExceptionCheckAndClear() const {
- JNIEnv* env = GetEnv();
- TC3_CHECK(env != nullptr);
- const bool result = env->ExceptionCheck();
- if (result) {
- env->ExceptionDescribe();
- env->ExceptionClear();
- }
- return result;
+ return JniExceptionCheckAndClear(GetEnv());
}
-ScopedLocalRef<jstring> JniCache::ConvertToJavaString(
+StatusOr<ScopedLocalRef<jstring>> JniCache::ConvertToJavaString(
const char* utf8_text, const int utf8_text_size_bytes) const {
// Create java byte array.
JNIEnv* jenv = GetEnv();
- const ScopedLocalRef<jbyteArray> text_java_utf8(
- jenv->NewByteArray(utf8_text_size_bytes), jenv);
- if (!text_java_utf8) {
- return nullptr;
- }
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jbyteArray> text_java_utf8,
+ JniHelper::NewByteArray(jenv, utf8_text_size_bytes));
jenv->SetByteArrayRegion(text_java_utf8.get(), 0, utf8_text_size_bytes,
reinterpret_cast<const jbyte*>(utf8_text));
// Create the string with a UTF-8 charset.
- return ScopedLocalRef<jstring>(
- reinterpret_cast<jstring>(
- jenv->NewObject(string_class.get(), string_init_bytes_charset,
- text_java_utf8.get(), string_utf8.get())),
- jenv);
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> result,
+ JniHelper::NewObject<jstring>(
+ jenv, string_class.get(), string_init_bytes_charset,
+ text_java_utf8.get(), string_utf8.get()));
+
+ return result;
}
-ScopedLocalRef<jstring> JniCache::ConvertToJavaString(
+StatusOr<ScopedLocalRef<jstring>> JniCache::ConvertToJavaString(
StringPiece utf8_text) const {
return ConvertToJavaString(utf8_text.data(), utf8_text.size());
}
-ScopedLocalRef<jstring> JniCache::ConvertToJavaString(
+StatusOr<ScopedLocalRef<jstring>> JniCache::ConvertToJavaString(
const UnicodeText& text) const {
return ConvertToJavaString(text.data(), text.size_bytes());
}
diff --git a/native/utils/java/jni-cache.h b/native/utils/java/jni-cache.h
index 609ddb1..ab48419 100644
--- a/native/utils/java/jni-cache.h
+++ b/native/utils/java/jni-cache.h
@@ -18,8 +18,9 @@
#define LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_CACHE_H_
#include <jni.h>
-#include "utils/java/scoped_global_ref.h"
-#include "utils/java/scoped_local_ref.h"
+
+#include "utils/base/statusor.h"
+#include "utils/java/jni-base.h"
#include "utils/strings/stringpiece.h"
#include "utils/utf8/unicodetext.h"
@@ -136,10 +137,12 @@
jmethodID resources_get_string = nullptr;
// Helper to convert lib3 UnicodeText to Java strings.
- ScopedLocalRef<jstring> ConvertToJavaString(
+ StatusOr<ScopedLocalRef<jstring>> ConvertToJavaString(
const char* utf8_text, const int utf8_text_size_bytes) const;
- ScopedLocalRef<jstring> ConvertToJavaString(StringPiece utf8_text) const;
- ScopedLocalRef<jstring> ConvertToJavaString(const UnicodeText& text) const;
+ StatusOr<ScopedLocalRef<jstring>> ConvertToJavaString(
+ StringPiece utf8_text) const;
+ StatusOr<ScopedLocalRef<jstring>> ConvertToJavaString(
+ const UnicodeText& text) const;
private:
explicit JniCache(JavaVM* jvm);
diff --git a/native/utils/java/jni-helper.cc b/native/utils/java/jni-helper.cc
new file mode 100644
index 0000000..c46c751
--- /dev/null
+++ b/native/utils/java/jni-helper.cc
@@ -0,0 +1,131 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/java/jni-helper.h"
+
+namespace libtextclassifier3 {
+
+StatusOr<ScopedLocalRef<jclass>> JniHelper::FindClass(JNIEnv* env,
+ const char* class_name) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ ScopedLocalRef<jclass> result(env->FindClass(class_name), env);
+ TC3_NOT_NULL_OR_RETURN;
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+StatusOr<jmethodID> JniHelper::GetMethodID(JNIEnv* env, jclass clazz,
+ const char* method_name,
+ const char* return_type) {
+ jmethodID result = env->GetMethodID(clazz, method_name, return_type);
+ TC3_NOT_NULL_OR_RETURN;
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+StatusOr<ScopedLocalRef<jbyteArray>> JniHelper::NewByteArray(JNIEnv* env,
+ jsize length) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ ScopedLocalRef<jbyteArray> result(env->NewByteArray(length), env);
+ TC3_NOT_NULL_OR_RETURN;
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+Status JniHelper::CallVoidMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
+ va_list args;
+ va_start(args, method_id);
+ env->CallVoidMethodV(object, method_id, args);
+ va_end(args);
+
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return Status::OK;
+}
+
+StatusOr<bool> JniHelper::CallBooleanMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
+ va_list args;
+ va_start(args, method_id);
+ bool result = env->CallBooleanMethodV(object, method_id, args);
+ va_end(args);
+
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+StatusOr<int32> JniHelper::CallIntMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
+ va_list args;
+ va_start(args, method_id);
+ jint result = env->CallIntMethodV(object, method_id, args);
+ va_end(args);
+
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+StatusOr<int64> JniHelper::CallLongMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...) {
+ va_list args;
+ va_start(args, method_id);
+ jlong result = env->CallLongMethodV(object, method_id, args);
+ va_end(args);
+
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+StatusOr<int32> JniHelper::CallStaticIntMethod(JNIEnv* env, jclass clazz,
+ jmethodID method_id, ...) {
+ va_list args;
+ va_start(args, method_id);
+ jint result = env->CallStaticIntMethodV(clazz, method_id, args);
+ va_end(args);
+
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+StatusOr<ScopedLocalRef<jintArray>> JniHelper::NewIntArray(JNIEnv* env,
+ jsize length) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ ScopedLocalRef<jintArray> result(env->NewIntArray(length), env);
+ TC3_NOT_NULL_OR_RETURN;
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+StatusOr<ScopedLocalRef<jobjectArray>> JniHelper::NewObjectArray(
+ JNIEnv* env, jsize length, jclass element_class, jobject initial_element) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ ScopedLocalRef<jobjectArray> result(
+ env->NewObjectArray(length, element_class, initial_element), env);
+ TC3_NOT_NULL_OR_RETURN;
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+StatusOr<ScopedLocalRef<jstring>> JniHelper::NewStringUTF(JNIEnv* env,
+ const char* bytes) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ ScopedLocalRef<jstring> result(env->NewStringUTF(bytes), env);
+ TC3_NOT_NULL_OR_RETURN;
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/java/jni-helper.h b/native/utils/java/jni-helper.h
new file mode 100644
index 0000000..8deae80
--- /dev/null
+++ b/native/utils/java/jni-helper.h
@@ -0,0 +1,129 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Utility class that provides similar calls like JNIEnv, but performs
+// additional checks on them, so that it's harder to use them incorrectly.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_HELPER_H_
+#define LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_HELPER_H_
+
+#include <jni.h>
+
+#include <string>
+
+#include "utils/base/status.h"
+#include "utils/base/statusor.h"
+#include "utils/java/jni-base.h"
+
+#define TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN \
+ if (!EnsureLocalCapacity(env, 1)) { \
+ TC3_LOG(ERROR) << "EnsureLocalCapacity(1) failed."; \
+ return {Status::UNKNOWN}; \
+ }
+
+#define TC3_NO_EXCEPTION_OR_RETURN \
+ if (JniExceptionCheckAndClear(env)) { \
+ return {Status::UNKNOWN}; \
+ }
+
+#define TC3_NOT_NULL_OR_RETURN \
+ if (result == nullptr) { \
+ return {Status::UNKNOWN}; \
+ }
+
+#define TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD( \
+ METHOD_NAME, RETURN_TYPE, INPUT_TYPE, POST_CHECK) \
+ template <typename T = RETURN_TYPE> \
+ static StatusOr<ScopedLocalRef<T>> METHOD_NAME( \
+ JNIEnv* env, INPUT_TYPE object, jmethodID method_id, ...) { \
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN; \
+ \
+ va_list args; \
+ va_start(args, method_id); \
+ ScopedLocalRef<T> result( \
+ reinterpret_cast<T>(env->METHOD_NAME##V(object, method_id, args)), \
+ env); \
+ POST_CHECK \
+ va_end(args); \
+ \
+ TC3_NO_EXCEPTION_OR_RETURN; \
+ return result; \
+ }
+
+#define TC3_JNI_NO_CHECK \
+ {}
+
+namespace libtextclassifier3 {
+
+class JniHelper {
+ public:
+ // Misc methods.
+ static StatusOr<ScopedLocalRef<jclass>> FindClass(JNIEnv* env,
+ const char* class_name);
+ template <typename T = jobject>
+ static StatusOr<ScopedLocalRef<T>> GetObjectArrayElement(JNIEnv* env,
+ jobjectArray array,
+ jsize index);
+ static StatusOr<jmethodID> GetMethodID(JNIEnv* env, jclass clazz,
+ const char* method_name,
+ const char* return_type);
+
+ // New* methods.
+ TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD(NewObject, jobject, jclass,
+ TC3_NOT_NULL_OR_RETURN);
+ static StatusOr<ScopedLocalRef<jobjectArray>> NewObjectArray(
+ JNIEnv* env, jsize length, jclass element_class,
+ jobject initial_element = nullptr);
+ static StatusOr<ScopedLocalRef<jbyteArray>> NewByteArray(JNIEnv* env,
+ jsize length);
+ static StatusOr<ScopedLocalRef<jintArray>> NewIntArray(JNIEnv* env,
+ jsize length);
+ static StatusOr<ScopedLocalRef<jstring>> NewStringUTF(JNIEnv* env,
+ const char* bytes);
+
+ // Call* methods.
+ TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD(CallObjectMethod, jobject,
+ jobject, TC3_JNI_NO_CHECK);
+ TC3_DEFINE_VARIADIC_SCOPED_LOCAL_REF_ENV_METHOD(CallStaticObjectMethod,
+ jobject, jclass,
+ TC3_JNI_NO_CHECK);
+ static Status CallVoidMethod(JNIEnv* env, jobject object, jmethodID method_id,
+ ...);
+ static StatusOr<bool> CallBooleanMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...);
+ static StatusOr<int32> CallIntMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...);
+ static StatusOr<int64> CallLongMethod(JNIEnv* env, jobject object,
+ jmethodID method_id, ...);
+ static StatusOr<int32> CallStaticIntMethod(JNIEnv* env, jclass clazz,
+ jmethodID method_id, ...);
+};
+
+template <typename T>
+StatusOr<ScopedLocalRef<T>> JniHelper::GetObjectArrayElement(JNIEnv* env,
+ jobjectArray array,
+ jsize index) {
+ TC3_ENSURE_LOCAL_CAPACITY_OR_RETURN;
+ ScopedLocalRef<T> result(
+ reinterpret_cast<T>(env->GetObjectArrayElement(array, index)), env);
+
+ TC3_NO_EXCEPTION_OR_RETURN;
+ return result;
+}
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_JNI_HELPER_H_
diff --git a/native/utils/java/scoped_global_ref.h b/native/utils/java/scoped_global_ref.h
deleted file mode 100644
index de0608e..0000000
--- a/native/utils/java/scoped_global_ref.h
+++ /dev/null
@@ -1,78 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_GLOBAL_REF_H_
-#define LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_GLOBAL_REF_H_
-
-#include <jni.h>
-#include <memory>
-#include <type_traits>
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-// A deleter to be used with std::unique_ptr to delete JNI global references.
-class GlobalRefDeleter {
- public:
- GlobalRefDeleter() : jvm_(nullptr) {}
-
- // Style guide violating implicit constructor so that the GlobalRefDeleter
- // is implicitly constructed from the second argument to ScopedGlobalRef.
- GlobalRefDeleter(JavaVM* jvm) : jvm_(jvm) {} // NOLINT(runtime/explicit)
-
- GlobalRefDeleter(const GlobalRefDeleter& orig) = default;
-
- // Copy assignment to allow move semantics in ScopedGlobalRef.
- GlobalRefDeleter& operator=(const GlobalRefDeleter& rhs) {
- TC3_CHECK_EQ(jvm_, rhs.jvm_);
- return *this;
- }
-
- // The delete operator.
- void operator()(jobject object) const {
- JNIEnv* env;
- if (object != nullptr && jvm_ != nullptr &&
- JNI_OK ==
- jvm_->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_4)) {
- env->DeleteGlobalRef(object);
- }
- }
-
- private:
- // The jvm_ stashed to use for deletion.
- JavaVM* const jvm_;
-};
-
-// A smart pointer that deletes a JNI global reference when it goes out
-// of scope. Usage is:
-// ScopedGlobalRef<jobject> scoped_global(env->JniFunction(), jvm);
-template <typename T>
-using ScopedGlobalRef =
- std::unique_ptr<typename std::remove_pointer<T>::type, GlobalRefDeleter>;
-
-// A helper to create global references. Assumes the object has a local
-// reference, which it deletes.
-template <typename T>
-ScopedGlobalRef<T> MakeGlobalRef(T object, JNIEnv* env, JavaVM* jvm) {
- const jobject global_object = env->NewGlobalRef(object);
- env->DeleteLocalRef(object);
- return ScopedGlobalRef<T>(reinterpret_cast<T>(global_object), jvm);
-}
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_GLOBAL_REF_H_
diff --git a/native/utils/java/scoped_local_ref.h b/native/utils/java/scoped_local_ref.h
deleted file mode 100644
index f439c45..0000000
--- a/native/utils/java/scoped_local_ref.h
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_LOCAL_REF_H_
-#define LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_LOCAL_REF_H_
-
-#include <jni.h>
-#include <memory>
-#include <type_traits>
-
-#include "utils/base/logging.h"
-
-namespace libtextclassifier3 {
-
-// A deleter to be used with std::unique_ptr to delete JNI local references.
-class LocalRefDeleter {
- public:
- LocalRefDeleter() : env_(nullptr) {}
-
- // Style guide violating implicit constructor so that the LocalRefDeleter
- // is implicitly constructed from the second argument to ScopedLocalRef.
- LocalRefDeleter(JNIEnv* env) : env_(env) {} // NOLINT(runtime/explicit)
-
- LocalRefDeleter(const LocalRefDeleter& orig) = default;
-
- // Copy assignment to allow move semantics in ScopedLocalRef.
- LocalRefDeleter& operator=(const LocalRefDeleter& rhs) {
- // As the deleter and its state are thread-local, ensure the envs
- // are consistent but do nothing.
- TC3_CHECK_EQ(env_, rhs.env_);
- return *this;
- }
-
- // The delete operator.
- void operator()(jobject object) const {
- if (env_) {
- env_->DeleteLocalRef(object);
- }
- }
-
- private:
- // The env_ stashed to use for deletion. Thread-local, don't share!
- JNIEnv* const env_;
-};
-
-// A smart pointer that deletes a JNI local reference when it goes out
-// of scope. Usage is:
-// ScopedLocalRef<jobject> scoped_local(env->JniFunction(), env);
-//
-// Note that this class is not thread-safe since it caches JNIEnv in
-// the deleter. Do not use the same jobject across different threads.
-template <typename T>
-using ScopedLocalRef =
- std::unique_ptr<typename std::remove_pointer<T>::type, LocalRefDeleter>;
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_UTILS_JAVA_SCOPED_LOCAL_REF_H_
diff --git a/native/utils/java/string_utils.cc b/native/utils/java/string_utils.cc
index 457a667..ca518a0 100644
--- a/native/utils/java/string_utils.cc
+++ b/native/utils/java/string_utils.cc
@@ -39,7 +39,7 @@
std::string* result) {
if (jstr == nullptr) {
*result = std::string();
- return false;
+ return true;
}
jclass string_class = env->FindClass("java/lang/String");
diff --git a/native/utils/lua-utils.cc b/native/utils/lua-utils.cc
index 64071ca..4da3171 100644
--- a/native/utils/lua-utils.cc
+++ b/native/utils/lua-utils.cc
@@ -133,7 +133,7 @@
const flatbuffers::String *string_value =
table->GetPointer<const flatbuffers::String *>(field->offset());
if (string_value != nullptr) {
- lua_pushlstring(state, string_value->data(), string_value->Length());
+ lua_pushlstring(state, string_value->data(), string_value->size());
} else {
lua_pushlstring(state, "", 0);
}
diff --git a/native/utils/normalization.cc b/native/utils/normalization.cc
new file mode 100644
index 0000000..fd64dbb
--- /dev/null
+++ b/native/utils/normalization.cc
@@ -0,0 +1,75 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/normalization.h"
+
+#include "utils/base/logging.h"
+#include "utils/normalization_generated.h"
+
+namespace libtextclassifier3 {
+
+UnicodeText NormalizeText(const UniLib* unilib,
+ const NormalizationOptions* normalization_options,
+ const UnicodeText& text) {
+ return NormalizeTextCodepointWise(
+ unilib, normalization_options->codepointwise_normalization(), text);
+}
+
+UnicodeText NormalizeTextCodepointWise(const UniLib* unilib,
+ const uint32 codepointwise_ops,
+ const UnicodeText& text) {
+ // Sanity check.
+ TC3_CHECK(!((codepointwise_ops &
+ NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE) &&
+ (codepointwise_ops &
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE)));
+
+ UnicodeText result;
+ for (const char32 codepoint : text) {
+ // Skip whitespace.
+ if ((codepointwise_ops &
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE) &&
+ unilib->IsWhitespace(codepoint)) {
+ continue;
+ }
+
+ // Skip punctuation.
+ if ((codepointwise_ops &
+ NormalizationOptions_::
+ CodepointwiseNormalizationOp_DROP_PUNCTUATION) &&
+ unilib->IsPunctuation(codepoint)) {
+ continue;
+ }
+
+ int32 normalized_codepoint = codepoint;
+
+ // Lower case.
+ if (codepointwise_ops &
+ NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE) {
+ normalized_codepoint = unilib->ToLower(normalized_codepoint);
+
+ // Upper case.
+ } else if (codepointwise_ops &
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE) {
+ normalized_codepoint = unilib->ToUpper(normalized_codepoint);
+ }
+
+ result.push_back(normalized_codepoint);
+ }
+ return result;
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/normalization.fbs b/native/utils/normalization.fbs
new file mode 100755
index 0000000..4d43f10
--- /dev/null
+++ b/native/utils/normalization.fbs
@@ -0,0 +1,40 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
+// The possible codepoint wise normalization operations.
+namespace libtextclassifier3.NormalizationOptions_;
+enum CodepointwiseNormalizationOp : int {
+ NONE = 0,
+
+ // Lower-case the string.
+ LOWERCASE = 1,
+
+ // Upper-case the string.
+ UPPERCASE = 4,
+
+ // Remove whitespace.
+ DROP_WHITESPACE = 8,
+
+ // Remove punctuation.
+ DROP_PUNCTUATION = 16,
+}
+
+namespace libtextclassifier3;
+table NormalizationOptions {
+ // Codepoint wise normalizations to apply, represents a bit field.
+ codepointwise_normalization:uint;
+}
+
diff --git a/native/utils/normalization.h b/native/utils/normalization.h
new file mode 100644
index 0000000..0ded163
--- /dev/null
+++ b/native/utils/normalization.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Methods for string normalization.
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_NORMALIZATION_H_
+#define LIBTEXTCLASSIFIER_UTILS_NORMALIZATION_H_
+
+#include "utils/base/integral_types.h"
+#include "utils/normalization_generated.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+
+namespace libtextclassifier3 {
+
+// Normalizes a text according to the options.
+UnicodeText NormalizeText(const UniLib* unilib,
+ const NormalizationOptions* normalization_options,
+ const UnicodeText& text);
+
+// Normalizes a text codepoint wise by applying each codepoint wise op in
+// `codepointwise_ops` that is interpreted as a set of
+// `CodepointwiseNormalizationOp`.
+UnicodeText NormalizeTextCodepointWise(const UniLib* unilib,
+ const uint32 codepointwise_ops,
+ const UnicodeText& text);
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_NORMALIZATION_H_
diff --git a/native/utils/normalization_test.cc b/native/utils/normalization_test.cc
new file mode 100644
index 0000000..1bf9fae
--- /dev/null
+++ b/native/utils/normalization_test.cc
@@ -0,0 +1,121 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/normalization.h"
+
+#include <string>
+
+#include "utils/base/integral_types.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+namespace {
+
+using testing::Eq;
+
+class NormalizationTest : public testing::Test {
+ protected:
+ NormalizationTest() : INIT_UNILIB_FOR_TESTING(unilib_) {}
+
+ std::string NormalizeTextCodepointWise(const std::string& text,
+ const int32 codepointwise_ops) {
+ return libtextclassifier3::NormalizeTextCodepointWise(
+ &unilib_, codepointwise_ops,
+ UTF8ToUnicodeText(text, /*do_copy=*/false))
+ .ToUTF8String();
+ }
+
+ UniLib unilib_;
+};
+
+TEST_F(NormalizationTest, ReturnsIdenticalStringWhenNoNormalization) {
+ EXPECT_THAT(NormalizeTextCodepointWise(
+ "Never gonna let you down.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_NONE),
+ Eq("Never gonna let you down."));
+}
+
+#if !defined(TC3_UNILIB_DUMMY)
+TEST_F(NormalizationTest, DropsWhitespace) {
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Never gonna let you down.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE),
+ Eq("Nevergonnaletyoudown."));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Never\tgonna\t\tlet\tyou\tdown.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE),
+ Eq("Nevergonnaletyoudown."));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Never\u2003gonna\u2003let\u2003you\u2003down.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE),
+ Eq("Nevergonnaletyoudown."));
+}
+
+TEST_F(NormalizationTest, DropsPunctuation) {
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Never gonna let you down.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_PUNCTUATION),
+ Eq("Never gonna let you down"));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "αʹ. Σημεῖόν ἐστιν, οὗ μέρος οὐθέν.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_PUNCTUATION),
+ Eq("αʹ Σημεῖόν ἐστιν οὗ μέρος οὐθέν"));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "978—3—16—148410—0",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_PUNCTUATION),
+ Eq("9783161484100"));
+}
+
+TEST_F(NormalizationTest, LowercasesUnicodeText) {
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "αʹ. Σημεῖόν ἐστιν, οὗ μέρος οὐθέν.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE),
+ Eq("αʹ. σημεῖόν ἐστιν, οὗ μέρος οὐθέν."));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "αʹ. Σημεῖόν ἐστιν, οὗ μέρος οὐθέν.",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
+ NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE),
+ Eq("αʹ.σημεῖόνἐστιν,οὗμέροςοὐθέν."));
+}
+
+TEST_F(NormalizationTest, UppercasesUnicodeText) {
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Κανένας άνθρωπος δεν ξέρει",
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE),
+ Eq("ΚΑΝΈΝΑΣ ΆΝΘΡΩΠΟΣ ΔΕΝ ΞΈΡΕΙ"));
+ EXPECT_THAT(
+ NormalizeTextCodepointWise(
+ "Κανένας άνθρωπος δεν ξέρει",
+ NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
+ NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE),
+ Eq("ΚΑΝΈΝΑΣΆΝΘΡΩΠΟΣΔΕΝΞΈΡΕΙ"));
+}
+#endif
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/resources.cc b/native/utils/resources.cc
index ddfa499..2ae2def 100644
--- a/native/utils/resources.cc
+++ b/native/utils/resources.cc
@@ -15,6 +15,7 @@
*/
#include "utils/resources.h"
+
#include "utils/base/logging.h"
#include "utils/zlib/buffer_generated.h"
#include "utils/zlib/zlib.h"
@@ -214,4 +215,34 @@
builder.GetSize());
}
+bool DecompressResources(ResourcePoolT* resources,
+ const bool build_compression_dictionary) {
+ std::vector<unsigned char> dictionary;
+
+ for (auto& entry : resources->resource_entry) {
+ for (auto& resource : entry->resource) {
+ if (resource->compressed_content == nullptr) {
+ continue;
+ }
+
+ std::unique_ptr<ZlibDecompressor> zlib_decompressor =
+ build_compression_dictionary
+ ? ZlibDecompressor::Instance(dictionary.data(), dictionary.size())
+ : ZlibDecompressor::Instance();
+ if (!zlib_decompressor) {
+ TC3_LOG(ERROR) << "Cannot initialize decompressor.";
+ return false;
+ }
+
+ if (!zlib_decompressor->MaybeDecompress(
+ resource->compressed_content.get(), &resource->content)) {
+ TC3_LOG(ERROR) << "Cannot decompress resource.";
+ return false;
+ }
+ resource->compressed_content.reset(nullptr);
+ }
+ }
+ return true;
+}
+
} // namespace libtextclassifier3
diff --git a/native/utils/resources.h b/native/utils/resources.h
index 28db0cc..81647bb 100644
--- a/native/utils/resources.h
+++ b/native/utils/resources.h
@@ -71,6 +71,9 @@
const bool build_compression_dictionary = false,
const int dictionary_sample_every = 1);
+bool DecompressResources(ResourcePoolT* resources,
+ const bool build_compression_dictionary = false);
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_RESOURCES_H_
diff --git a/native/utils/strings/stringpiece.h b/native/utils/strings/stringpiece.h
index 0dec1b8..cb45a0a 100644
--- a/native/utils/strings/stringpiece.h
+++ b/native/utils/strings/stringpiece.h
@@ -17,7 +17,7 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_
#define LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_
-#include <stddef.h>
+#include <cstddef>
#include <string>
#include "utils/base/logging.h"
@@ -29,23 +29,23 @@
public:
StringPiece() : StringPiece(nullptr, 0) {}
- StringPiece(const char *str) // NOLINT(runtime/explicit)
+ StringPiece(const char* str) // NOLINT(runtime/explicit)
: start_(str), size_(str == nullptr ? 0 : strlen(str)) {}
- StringPiece(const char *start, size_t size) : start_(start), size_(size) {}
+ StringPiece(const char* start, size_t size) : start_(start), size_(size) {}
// Intentionally no "explicit" keyword: in function calls, we want strings to
// be converted to StringPiece implicitly.
- StringPiece(const std::string &s) // NOLINT(runtime/explicit)
+ StringPiece(const std::string& s) // NOLINT(runtime/explicit)
: StringPiece(s.data(), s.size()) {}
- StringPiece(const std::string &s, int offset, int len)
+ StringPiece(const std::string& s, int offset, int len)
: StringPiece(s.data() + offset, len) {}
char operator[](size_t i) const { return start_[i]; }
// Returns start address of underlying data.
- const char *data() const { return start_; }
+ const char* data() const { return start_; }
// Returns number of bytes of underlying data.
size_t size() const { return size_; }
@@ -83,7 +83,7 @@
}
private:
- const char *start_; // Not owned.
+ const char* start_; // Not owned.
size_t size_;
};
@@ -95,7 +95,7 @@
return text.StartsWith(prefix);
}
-inline bool ConsumePrefix(StringPiece *text, StringPiece prefix) {
+inline bool ConsumePrefix(StringPiece* text, StringPiece prefix) {
if (!text->StartsWith(prefix)) {
return false;
}
@@ -103,6 +103,12 @@
return true;
}
+inline logging::LoggingStringStream& operator<<(
+ logging::LoggingStringStream& stream, StringPiece message) {
+ stream.message.append(message.data(), message.size());
+ return stream;
+}
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_STRINGS_STRINGPIECE_H_
diff --git a/native/utils/tflite-model-executor.cc b/native/utils/tflite-model-executor.cc
index 4ad60cd..8bdefde 100644
--- a/native/utils/tflite-model-executor.cc
+++ b/native/utils/tflite-model-executor.cc
@@ -191,7 +191,7 @@
const tflite::Model* model =
flatbuffers::GetRoot<tflite::Model>(model_spec_buffer->data());
flatbuffers::Verifier verifier(model_spec_buffer->data(),
- model_spec_buffer->Length());
+ model_spec_buffer->size());
if (!model->Verify(verifier)) {
return nullptr;
}
diff --git a/native/utils/tflite-model-executor.h b/native/utils/tflite-model-executor.h
index e9c6af9..a4432ff 100644
--- a/native/utils/tflite-model-executor.h
+++ b/native/utils/tflite-model-executor.h
@@ -19,11 +19,13 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_UTILS_TFLITE_MODEL_EXECUTOR_H_
+#include <cstdint>
#include <memory>
#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/op_resolver.h"
@@ -85,25 +87,25 @@
interpreter->tensor(interpreter->inputs()[input_index]);
switch (input_tensor->type) {
case kTfLiteFloat32:
- *(input_tensor->data.f) = input_value;
+ *tflite::GetTensorData<float>(input_tensor) = input_value;
break;
case kTfLiteInt32:
- *(input_tensor->data.i32) = input_value;
+ *tflite::GetTensorData<int32_t>(input_tensor) = input_value;
break;
case kTfLiteUInt8:
- *(input_tensor->data.uint8) = input_value;
+ *tflite::GetTensorData<uint8_t>(input_tensor) = input_value;
break;
case kTfLiteInt64:
- *(input_tensor->data.i64) = input_value;
+ *tflite::GetTensorData<int64_t>(input_tensor) = input_value;
break;
case kTfLiteBool:
- *(input_tensor->data.b) = input_value;
+ *tflite::GetTensorData<bool>(input_tensor) = input_value;
break;
case kTfLiteInt16:
- *(input_tensor->data.i16) = input_value;
+ *tflite::GetTensorData<int16_t>(input_tensor) = input_value;
break;
case kTfLiteInt8:
- *(input_tensor->data.int8) = input_value;
+ *tflite::GetTensorData<int8_t>(input_tensor) = input_value;
break;
default:
break;
diff --git a/native/utils/tflite/text_encoder.cc b/native/utils/tflite/text_encoder.cc
index c7811ea..68380ea 100644
--- a/native/utils/tflite/text_encoder.cc
+++ b/native/utils/tflite/text_encoder.cc
@@ -91,7 +91,7 @@
const TrieNode* charsmap_trie_nodes = reinterpret_cast<const TrieNode*>(
config->normalization_charsmap()->Data());
const int charsmap_trie_nodes_length =
- config->normalization_charsmap()->Length() / sizeof(TrieNode);
+ config->normalization_charsmap()->size() / sizeof(TrieNode);
encoder_op->normalizer.reset(new SentencePieceNormalizer(
DoubleArrayTrie(charsmap_trie_nodes, charsmap_trie_nodes_length),
StringPiece(config->normalization_charsmap_values()->data(),
diff --git a/native/utils/tflite/text_encoder_test.cc b/native/utils/tflite/text_encoder_test.cc
index ae752f5..6386432 100644
--- a/native/utils/tflite/text_encoder_test.cc
+++ b/native/utils/tflite/text_encoder_test.cc
@@ -39,7 +39,7 @@
public:
TextEncoderOpModel(std::initializer_list<int> input_strings_shape,
std::initializer_list<int> attribute_shape);
- void SetInputText(const std::initializer_list<string>& strings) {
+ void SetInputText(const std::initializer_list<std::string>& strings) {
PopulateStringTensor(input_string_, strings);
PopulateTensor(input_length_, {static_cast<int32_t>(strings.size())});
}
diff --git a/native/utils/utf8/unicodetext.cc b/native/utils/utf8/unicodetext.cc
index b3b092e..d58096f 100644
--- a/native/utils/utf8/unicodetext.cc
+++ b/native/utils/utf8/unicodetext.cc
@@ -320,8 +320,8 @@
return UTF8ToUnicodeText(str.data(), str.size(), do_copy);
}
-UnicodeText UTF8ToUnicodeText(const std::string& str) {
- return UTF8ToUnicodeText(str, /*do_copy=*/true);
+UnicodeText UTF8ToUnicodeText(StringPiece str, bool do_copy) {
+ return UTF8ToUnicodeText(str.data(), str.size(), do_copy);
}
} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unicodetext.h b/native/utils/utf8/unicodetext.h
index 3f884f9..de3b742 100644
--- a/native/utils/utf8/unicodetext.h
+++ b/native/utils/utf8/unicodetext.h
@@ -22,6 +22,8 @@
#include <utility>
#include "utils/base/integral_types.h"
+#include "utils/base/logging.h"
+#include "utils/strings/stringpiece.h"
namespace libtextclassifier3 {
@@ -222,7 +224,13 @@
bool do_copy = true);
UnicodeText UTF8ToUnicodeText(const char* utf8_buf, bool do_copy = true);
UnicodeText UTF8ToUnicodeText(const std::string& str, bool do_copy = true);
-UnicodeText UTF8ToUnicodeText(const std::string& str);
+UnicodeText UTF8ToUnicodeText(StringPiece str, bool do_copy = true);
+
+inline logging::LoggingStringStream& operator<<(
+ logging::LoggingStringStream& stream, const UnicodeText& message) {
+ stream.message.append(message.data(), message.size_bytes());
+ return stream;
+}
} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unicodetext_test.cc b/native/utils/utf8/unicodetext_test.cc
index e6926ce..48d4da8 100644
--- a/native/utils/utf8/unicodetext_test.cc
+++ b/native/utils/utf8/unicodetext_test.cc
@@ -16,6 +16,7 @@
#include "utils/utf8/unicodetext.h"
+#include "utils/strings/stringpiece.h"
#include "gtest/gtest.h"
namespace libtextclassifier3 {
@@ -49,6 +50,21 @@
EXPECT_EQ(text.UTF8Substring(it_begin, it_end), "😋h");
}
+TEST(UnicodeTextTest, StringPieceView) {
+ std::string raw_text = "1234😋hello";
+ UnicodeText text =
+ UTF8ToUnicodeText(StringPiece(raw_text), /*do_copy=*/false);
+ EXPECT_EQ(text.ToUTF8String(), "1234😋hello");
+ EXPECT_EQ(text.size_codepoints(), 10);
+ EXPECT_EQ(text.size_bytes(), 13);
+
+ auto it_begin = text.begin();
+ std::advance(it_begin, 4);
+ auto it_end = text.begin();
+ std::advance(it_end, 6);
+ EXPECT_EQ(text.UTF8Substring(it_begin, it_end), "😋h");
+}
+
TEST(UnicodeTextTest, Substring) {
UnicodeText text = UTF8ToUnicodeText("1234😋hello", /*do_copy=*/false);
diff --git a/native/utils/utf8/unilib-common.cc b/native/utils/utf8/unilib-common.cc
index 2b6deda..3b466ed 100644
--- a/native/utils/utf8/unilib-common.cc
+++ b/native/utils/utf8/unilib-common.cc
@@ -52,12 +52,12 @@
// grep -E "WS" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
constexpr char32 kWhitespaces[] = {
- 0x000C, 0x0020, 0x1680, 0x2000, 0x2001, 0x2002, 0x2003, 0x2004,
- 0x2005, 0x2006, 0x2007, 0x2008, 0x2009, 0x200A, 0x2028, 0x205F,
- 0x21C7, 0x21C8, 0x21C9, 0x21CA, 0x21F6, 0x2B31, 0x2B84, 0x2B85,
- 0x2B86, 0x2B87, 0x2B94, 0x3000, 0x4DCC, 0x10344, 0x10347, 0x1DA0A,
- 0x1DA0B, 0x1DA0C, 0x1DA0D, 0x1DA0E, 0x1DA0F, 0x1DA10, 0x1F4F0, 0x1F500,
- 0x1F501, 0x1F502, 0x1F503, 0x1F504, 0x1F5D8, 0x1F5DE};
+ 0x0009, 0x000C, 0x0020, 0x1680, 0x2000, 0x2001, 0x2002, 0x2003,
+ 0x2004, 0x2005, 0x2006, 0x2007, 0x2008, 0x2009, 0x200A, 0x2028,
+ 0x205F, 0x21C7, 0x21C8, 0x21C9, 0x21CA, 0x21F6, 0x2B31, 0x2B84,
+ 0x2B85, 0x2B86, 0x2B87, 0x2B94, 0x3000, 0x4DCC, 0x10344, 0x10347,
+ 0x1DA0A, 0x1DA0B, 0x1DA0C, 0x1DA0D, 0x1DA0E, 0x1DA0F, 0x1DA10, 0x1F4F0,
+ 0x1F500, 0x1F501, 0x1F502, 0x1F503, 0x1F504, 0x1F5D8, 0x1F5DE};
constexpr int kNumWhitespaces = ARRAYSIZE(kWhitespaces);
// grep -E "Nd" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
@@ -73,6 +73,58 @@
0x118e9, 0x11c59, 0x11d59, 0x16a69, 0x16b59, 0x1d7ff};
constexpr int kNumDecimalDigitRangesEnd = ARRAYSIZE(kDecimalDigitRangesEnd);
+// grep -E ";P.;" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
+constexpr char32 kPunctuationRangesStart[] = {
+ 0x0021, 0x0025, 0x002c, 0x003a, 0x003f, 0x005b, 0x005f, 0x007b,
+ 0x007d, 0x00a1, 0x00a7, 0x00ab, 0x00b6, 0x00bb, 0x00bf, 0x037e,
+ 0x0387, 0x055a, 0x0589, 0x05be, 0x05c0, 0x05c3, 0x05c6, 0x05f3,
+ 0x0609, 0x060c, 0x061b, 0x061e, 0x066a, 0x06d4, 0x0700, 0x07f7,
+ 0x0830, 0x085e, 0x0964, 0x0970, 0x09fd, 0x0a76, 0x0af0, 0x0c77,
+ 0x0c84, 0x0df4, 0x0e4f, 0x0e5a, 0x0f04, 0x0f14, 0x0f3a, 0x0f85,
+ 0x0fd0, 0x0fd9, 0x104a, 0x10fb, 0x1360, 0x1400, 0x166e, 0x169b,
+ 0x16eb, 0x1735, 0x17d4, 0x17d8, 0x1800, 0x1944, 0x1a1e, 0x1aa0,
+ 0x1aa8, 0x1b5a, 0x1bfc, 0x1c3b, 0x1c7e, 0x1cc0, 0x1cd3, 0x2010,
+ 0x2030, 0x2045, 0x2053, 0x207d, 0x208d, 0x2308, 0x2329, 0x2768,
+ 0x27c5, 0x27e6, 0x2983, 0x29d8, 0x29fc, 0x2cf9, 0x2cfe, 0x2d70,
+ 0x2e00, 0x2e30, 0x3001, 0x3008, 0x3014, 0x3030, 0x303d, 0x30a0,
+ 0x30fb, 0xa4fe, 0xa60d, 0xa673, 0xa67e, 0xa6f2, 0xa874, 0xa8ce,
+ 0xa8f8, 0xa8fc, 0xa92e, 0xa95f, 0xa9c1, 0xa9de, 0xaa5c, 0xaade,
+ 0xaaf0, 0xabeb, 0xfd3e, 0xfe10, 0xfe30, 0xfe54, 0xfe63, 0xfe68,
+ 0xfe6a, 0xff01, 0xff05, 0xff0c, 0xff1a, 0xff1f, 0xff3b, 0xff3f,
+ 0xff5b, 0xff5d, 0xff5f, 0x10100, 0x1039f, 0x103d0, 0x1056f, 0x10857,
+ 0x1091f, 0x1093f, 0x10a50, 0x10a7f, 0x10af0, 0x10b39, 0x10b99, 0x10f55,
+ 0x11047, 0x110bb, 0x110be, 0x11140, 0x11174, 0x111c5, 0x111cd, 0x111db,
+ 0x111dd, 0x11238, 0x112a9, 0x1144b, 0x1145b, 0x1145d, 0x114c6, 0x115c1,
+ 0x11641, 0x11660, 0x1173c, 0x1183b, 0x119e2, 0x11a3f, 0x11a9a, 0x11a9e,
+ 0x11c41, 0x11c70, 0x11ef7, 0x11fff, 0x12470, 0x16a6e, 0x16af5, 0x16b37,
+ 0x16b44, 0x16e97, 0x16fe2, 0x1bc9f, 0x1da87, 0x1e95e};
+constexpr int kNumPunctuationRangesStart = ARRAYSIZE(kPunctuationRangesStart);
+constexpr char32 kPunctuationRangesEnd[] = {
+ 0x0023, 0x002a, 0x002f, 0x003b, 0x0040, 0x005d, 0x005f, 0x007b,
+ 0x007d, 0x00a1, 0x00a7, 0x00ab, 0x00b7, 0x00bb, 0x00bf, 0x037e,
+ 0x0387, 0x055f, 0x058a, 0x05be, 0x05c0, 0x05c3, 0x05c6, 0x05f4,
+ 0x060a, 0x060d, 0x061b, 0x061f, 0x066d, 0x06d4, 0x070d, 0x07f9,
+ 0x083e, 0x085e, 0x0965, 0x0970, 0x09fd, 0x0a76, 0x0af0, 0x0c77,
+ 0x0c84, 0x0df4, 0x0e4f, 0x0e5b, 0x0f12, 0x0f14, 0x0f3d, 0x0f85,
+ 0x0fd4, 0x0fda, 0x104f, 0x10fb, 0x1368, 0x1400, 0x166e, 0x169c,
+ 0x16ed, 0x1736, 0x17d6, 0x17da, 0x180a, 0x1945, 0x1a1f, 0x1aa6,
+ 0x1aad, 0x1b60, 0x1bff, 0x1c3f, 0x1c7f, 0x1cc7, 0x1cd3, 0x2027,
+ 0x2043, 0x2051, 0x205e, 0x207e, 0x208e, 0x230b, 0x232a, 0x2775,
+ 0x27c6, 0x27ef, 0x2998, 0x29db, 0x29fd, 0x2cfc, 0x2cff, 0x2d70,
+ 0x2e2e, 0x2e4f, 0x3003, 0x3011, 0x301f, 0x3030, 0x303d, 0x30a0,
+ 0x30fb, 0xa4ff, 0xa60f, 0xa673, 0xa67e, 0xa6f7, 0xa877, 0xa8cf,
+ 0xa8fa, 0xa8fc, 0xa92f, 0xa95f, 0xa9cd, 0xa9df, 0xaa5f, 0xaadf,
+ 0xaaf1, 0xabeb, 0xfd3f, 0xfe19, 0xfe52, 0xfe61, 0xfe63, 0xfe68,
+ 0xfe6b, 0xff03, 0xff0a, 0xff0f, 0xff1b, 0xff20, 0xff3d, 0xff3f,
+ 0xff5b, 0xff5d, 0xff65, 0x10102, 0x1039f, 0x103d0, 0x1056f, 0x10857,
+ 0x1091f, 0x1093f, 0x10a58, 0x10a7f, 0x10af6, 0x10b3f, 0x10b9c, 0x10f59,
+ 0x1104d, 0x110bc, 0x110c1, 0x11143, 0x11175, 0x111c8, 0x111cd, 0x111db,
+ 0x111df, 0x1123d, 0x112a9, 0x1144f, 0x1145b, 0x1145d, 0x114c6, 0x115d7,
+ 0x11643, 0x1166c, 0x1173e, 0x1183b, 0x119e2, 0x11a46, 0x11a9c, 0x11aa2,
+ 0x11c45, 0x11c71, 0x11ef8, 0x11fff, 0x12474, 0x16a6f, 0x16af5, 0x16b3b,
+ 0x16b44, 0x16e9a, 0x16fe2, 0x1bc9f, 0x1da8b, 0x1e95f};
+constexpr int kNumPunctuationRangesEnd = ARRAYSIZE(kPunctuationRangesEnd);
+
// grep -E "Lu" UnicodeData.txt | sed -re "s/([0-9A-Z]+);.*/0x\1, /"
// There are three common ways in which upper/lower case codepoint ranges
// were introduced: one offs, dense ranges, and ranges that alternate between
@@ -180,15 +232,38 @@
0x2c63, 0x2c64, 0x2c6d, 0x2c6e, 0x2c6f, 0x2c70, 0xa77d, 0xa78d, 0xa7aa,
0xa7ab, 0xa7ac, 0xa7ad, 0xa7ae, 0xa7b0, 0xa7b1, 0xa7b2, 0xa7b3};
constexpr int kNumToLowerSingles = ARRAYSIZE(kToLowerSingles);
+constexpr int kToLowerSinglesOffsets[] = {
+ -199, -121, 210, 206, 1, 79, 202, 203, 1,
+ 207, 211, 209, 1, 211, 213, 214, 218, 218,
+ 218, 219, -97, -56, -130, 10795, -163, 10792, -195,
+ 69, 71, 116, 38, 64, 8, -60, -7, 15,
+ -7615, -7, -7517, -8383, -8262, 28, 1, 1, -10743,
+ -3814, -10727, -10780, -10749, -10783, -10782, -35332, -42280, -42308,
+ -42319, -42315, -42305, -42308, -42258, -42282, -42261, 928};
+constexpr int kNumToLowerSinglesOffsets = ARRAYSIZE(kToLowerSinglesOffsets);
constexpr int kToUpperSingles[] = {
- 0x0069, 0x00ff, 0x0253, 0x0254, 0x018c, 0x01dd, 0x0259, 0x025b, 0x0192,
- 0x0263, 0x0269, 0x0268, 0x0199, 0x026f, 0x0272, 0x0275, 0x0280, 0x0283,
- 0x0288, 0x0292, 0x0195, 0x01bf, 0x019e, 0x2c65, 0x019a, 0x2c66, 0x0180,
- 0x0289, 0x028c, 0x03f3, 0x03ac, 0x03cc, 0x03d7, 0x03b8, 0x03f2, 0x04cf,
- 0x00df, 0x1fe5, 0x03c9, 0x006b, 0x00e5, 0x214e, 0x2184, 0x2c61, 0x026b,
- 0x1d7d, 0x027d, 0x0251, 0x0271, 0x0250, 0x0252, 0x1d79, 0x0265, 0x0266,
- 0x025c, 0x0261, 0x026c, 0x026a, 0x029e, 0x0287, 0x029d, 0xab53};
+ 0x00b5, 0x00ff, 0x0131, 0x017f, 0x0180, 0x0195, 0x0199, 0x019a, 0x019e,
+ 0x01bf, 0x01dd, 0x01f3, 0x0250, 0x0251, 0x0252, 0x0253, 0x0254, 0x0259,
+ 0x025b, 0x025c, 0x0260, 0x0261, 0x0263, 0x0265, 0x0266, 0x0268, 0x0269,
+ 0x026a, 0x026b, 0x026c, 0x026f, 0x0271, 0x0272, 0x0275, 0x027d, 0x0280,
+ 0x0282, 0x0283, 0x0287, 0x0288, 0x0289, 0x028c, 0x0292, 0x029d, 0x029e,
+ 0x03ac, 0x03c2, 0x03cc, 0x03d0, 0x03d1, 0x03d5, 0x03d6, 0x03d7, 0x03f0,
+ 0x03f1, 0x03f2, 0x03f3, 0x03f5, 0x04cf, 0x1c80, 0x1c81, 0x1c82, 0x1c85,
+ 0x1c86, 0x1c87, 0x1c88, 0x1d79, 0x1d7d, 0x1d8e, 0x1e9b, 0x1fb3, 0x1fbe,
+ 0x1fc3, 0x1fe5, 0x1ff3, 0x214e, 0x2184, 0x2c61, 0x2c65, 0x2c66, 0xa794,
+ 0xab53};
constexpr int kNumToUpperSingles = ARRAYSIZE(kToUpperSingles);
+constexpr int kToUpperSinglesOffsets[] = {
+ 743, 121, -232, -300, 195, 97, -1, 163, 130, 56,
+ -79, -2, 10783, 10780, 10782, -210, -206, -202, -203, 42319,
+ -205, 42315, -207, 42280, 42308, -209, -211, 42308, 10743, 42305,
+ -211, 10749, -213, -214, 10727, -218, 42307, -218, 42282, -218,
+ -69, -71, -219, 42261, 42258, -38, -31, -64, -62, -57,
+ -47, -54, -8, -86, -80, 7, -116, -96, -15, -6254,
+ -6253, -6244, -6243, -6236, -6181, 35266, 35332, 3814, 35384, -59,
+ 9, -7205, 9, 7, 9, -28, -1, -1, -10795, -10792,
+ 48, -928};
+constexpr int kNumToUpperSinglesOffsets = ARRAYSIZE(kToUpperSinglesOffsets);
constexpr int kToLowerRangesStart[] = {
0x0041, 0x0100, 0x0189, 0x01a0, 0x01b1, 0x01b3, 0x0388, 0x038e, 0x0391,
0x03d8, 0x03fd, 0x0400, 0x0410, 0x0460, 0x0531, 0x10a0, 0x13a0, 0x13f0,
@@ -207,21 +282,23 @@
-8, -112, -128, -126, 48, 1, -10815, 1, 32, 40, 64, 32};
constexpr int kNumToLowerRangesOffsets = ARRAYSIZE(kToLowerRangesOffsets);
constexpr int kToUpperRangesStart[] = {
- 0x0061, 0x0101, 0x01a1, 0x01b4, 0x023f, 0x0256, 0x028a, 0x037b, 0x03ad,
- 0x03b1, 0x03cd, 0x03d9, 0x0430, 0x0450, 0x0461, 0x0561, 0x13f8, 0x1e01,
- 0x1f00, 0x1f70, 0x1f72, 0x1f76, 0x1f78, 0x1f7a, 0x1f7c, 0x1fd0, 0x1fe0,
- 0x2c30, 0x2c68, 0x2c81, 0x2d00, 0xab70, 0xff41, 0x10428, 0x10cc0, 0x118c0};
+ 0x0061, 0x0101, 0x01c6, 0x01ce, 0x023f, 0x0242, 0x0256, 0x028a,
+ 0x0371, 0x037b, 0x03ad, 0x03b1, 0x03cd, 0x03d9, 0x0430, 0x0450,
+ 0x0461, 0x0561, 0x10d0, 0x13f8, 0x1c83, 0x1e01, 0x1f00, 0x1f70,
+ 0x1f72, 0x1f76, 0x1f78, 0x1f7a, 0x1f7c, 0x1f80, 0x2c30, 0x2c68,
+ 0x2d00, 0xa641, 0xab70, 0xff41, 0x10428, 0x10cc0, 0x118c0};
constexpr int kNumToUpperRangesStart = ARRAYSIZE(kToUpperRangesStart);
constexpr int kToUpperRangesEnd[] = {
- 0x00fe, 0x0188, 0x01b0, 0x0387, 0x0240, 0x026c, 0x028b, 0x037d, 0x03b0,
- 0x03ef, 0x03ce, 0x03fb, 0x044f, 0x045f, 0x052f, 0x0586, 0x13fd, 0x1eff,
- 0x1fb1, 0x1f71, 0x1f75, 0x1f77, 0x1f79, 0x1f7c, 0x2105, 0x1fd1, 0x1fe1,
- 0x2c94, 0x2c76, 0xa7b7, 0x2d2d, 0xabbf, 0xff5a, 0x104fb, 0x10cf2, 0x118df};
+ 0x00fe, 0x01bd, 0x01cc, 0x023c, 0x0240, 0x024f, 0x0257, 0x028b,
+ 0x0377, 0x037d, 0x03af, 0x03cb, 0x03ce, 0x03fb, 0x044f, 0x045f,
+ 0x052f, 0x0586, 0x10ff, 0x13fd, 0x1c84, 0x1eff, 0x1f67, 0x1f71,
+ 0x1f75, 0x1f77, 0x1f79, 0x1f7b, 0x1f7d, 0x1fe1, 0x2c5e, 0x2cf3,
+ 0x2d2d, 0xa7c3, 0xabbf, 0xff5a, 0x104fb, 0x10cf2, 0x16e7f};
constexpr int kNumToUpperRangesEnd = ARRAYSIZE(kToUpperRangesEnd);
constexpr int kToUpperRangesOffsets[]{
- -32, -1, -1, -1, 10815, -205, -217, 130, -37, -32, -63, -1,
- -32, -80, -1, -48, -8, -1, 8, 74, 86, 100, 128, 112,
- 126, 8, 8, -48, -1, -1, -7264, -38864, -32, -40, -64, -32};
+ -32, -1, -2, -1, 10815, -1, -205, -217, -1, 130, -37, -32, -63,
+ -1, -32, -80, -1, -48, 3008, -8, -6242, -1, 8, 74, 86, 100,
+ 128, 112, 126, 8, -48, -1, -7264, -1, -38864, -32, -40, -64, -32};
constexpr int kNumToUpperRangesOffsets = ARRAYSIZE(kToUpperRangesOffsets);
#undef ARRAYSIZE
@@ -236,16 +313,20 @@
"number of uppercase stride 1 range starts/ends doesn't match");
static_assert(kNumUpperRanges2Start == kNumUpperRanges2End,
"number of uppercase stride 2 range starts/ends doesn't match");
-static_assert(kNumToLowerSingles == kNumToUpperSingles,
- "number of to lower and upper singles doesn't match");
+static_assert(kNumToLowerSingles == kNumToLowerSinglesOffsets,
+ "number of to lower singles and offsets doesn't match");
static_assert(kNumToLowerRangesStart == kNumToLowerRangesEnd,
"mismatching number of range starts/ends for to lower ranges");
static_assert(kNumToLowerRangesStart == kNumToLowerRangesOffsets,
"number of to lower ranges and offsets doesn't match");
+static_assert(kNumToUpperSingles == kNumToUpperSinglesOffsets,
+ "number of to upper singles and offsets doesn't match");
static_assert(kNumToUpperRangesStart == kNumToUpperRangesEnd,
"mismatching number of range starts/ends for to upper ranges");
static_assert(kNumToUpperRangesStart == kNumToUpperRangesOffsets,
"number of to upper ranges and offsets doesn't match");
+static_assert(kNumPunctuationRangesStart == kNumPunctuationRangesEnd,
+ "mismatch number of start/ends for punctuation ranges.");
constexpr int kNoMatch = -1;
@@ -357,6 +438,12 @@
}
}
+bool IsPunctuation(char32 codepoint) {
+ return (GetOverlappingRangeIndex(
+ kPunctuationRangesStart, kPunctuationRangesEnd,
+ kNumPunctuationRangesStart, /*stride=*/1, codepoint) >= 0);
+}
+
char32 ToLower(char32 codepoint) {
// Make sure we still produce output even if the method is called for a
// codepoint that's not an uppercase character.
@@ -366,7 +453,7 @@
const int singles_idx =
GetMatchIndex(kToLowerSingles, kNumToLowerSingles, codepoint);
if (singles_idx >= 0) {
- return kToUpperSingles[singles_idx];
+ return codepoint + kToLowerSinglesOffsets[singles_idx];
}
const int ranges_idx =
GetOverlappingRangeIndex(kToLowerRangesStart, kToLowerRangesEnd,
@@ -386,7 +473,7 @@
const int singles_idx =
GetMatchIndex(kToUpperSingles, kNumToUpperSingles, codepoint);
if (singles_idx >= 0) {
- return kToLowerSingles[singles_idx];
+ return codepoint + kToUpperSinglesOffsets[singles_idx];
}
const int ranges_idx =
GetOverlappingRangeIndex(kToUpperRangesStart, kToUpperRangesEnd,
diff --git a/native/utils/utf8/unilib-common.h b/native/utils/utf8/unilib-common.h
index 0394cc3..442f256 100644
--- a/native/utils/utf8/unilib-common.h
+++ b/native/utils/utf8/unilib-common.h
@@ -18,6 +18,7 @@
#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_COMMON_H_
#include "utils/base/integral_types.h"
+#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -27,6 +28,7 @@
bool IsDigit(char32 codepoint);
bool IsLower(char32 codepoint);
bool IsUpper(char32 codepoint);
+bool IsPunctuation(char32 codepoint);
char32 ToLower(char32 codepoint);
char32 ToUpper(char32 codepoint);
char32 GetPairedBracket(char32 codepoint);
diff --git a/native/utils/utf8/unilib-javaicu.cc b/native/utils/utf8/unilib-javaicu.cc
index 13bb536..64a4434 100644
--- a/native/utils/utf8/unilib-javaicu.cc
+++ b/native/utils/utf8/unilib-javaicu.cc
@@ -20,51 +20,58 @@
#include <cctype>
#include <map>
+#include "utils/base/statusor.h"
+#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
#include "utils/java/string_utils.h"
#include "utils/utf8/unilib-common.h"
namespace libtextclassifier3 {
-UniLib::UniLib() {
+UniLibBase::UniLibBase() {
TC3_LOG(FATAL) << "Java ICU UniLib must be initialized with a JniCache.";
}
-UniLib::UniLib(const std::shared_ptr<JniCache>& jni_cache)
+UniLibBase::UniLibBase(const std::shared_ptr<JniCache>& jni_cache)
: jni_cache_(jni_cache) {}
-bool UniLib::IsOpeningBracket(char32 codepoint) const {
+bool UniLibBase::IsOpeningBracket(char32 codepoint) const {
return libtextclassifier3::IsOpeningBracket(codepoint);
}
-bool UniLib::IsClosingBracket(char32 codepoint) const {
+bool UniLibBase::IsClosingBracket(char32 codepoint) const {
return libtextclassifier3::IsClosingBracket(codepoint);
}
-bool UniLib::IsWhitespace(char32 codepoint) const {
+bool UniLibBase::IsWhitespace(char32 codepoint) const {
return libtextclassifier3::IsWhitespace(codepoint);
}
-bool UniLib::IsDigit(char32 codepoint) const {
+bool UniLibBase::IsDigit(char32 codepoint) const {
return libtextclassifier3::IsDigit(codepoint);
}
-bool UniLib::IsLower(char32 codepoint) const {
+bool UniLibBase::IsLower(char32 codepoint) const {
return libtextclassifier3::IsLower(codepoint);
}
-bool UniLib::IsUpper(char32 codepoint) const {
+bool UniLibBase::IsUpper(char32 codepoint) const {
return libtextclassifier3::IsUpper(codepoint);
}
-char32 UniLib::ToLower(char32 codepoint) const {
+bool UniLibBase::IsPunctuation(char32 codepoint) const {
+ return libtextclassifier3::IsPunctuation(codepoint);
+}
+
+char32 UniLibBase::ToLower(char32 codepoint) const {
return libtextclassifier3::ToLower(codepoint);
}
-char32 UniLib::ToUpper(char32 codepoint) const {
+char32 UniLibBase::ToUpper(char32 codepoint) const {
return libtextclassifier3::ToUpper(codepoint);
}
-char32 UniLib::GetPairedBracket(char32 codepoint) const {
+char32 UniLibBase::GetPairedBracket(char32 codepoint) const {
return libtextclassifier3::GetPairedBracket(codepoint);
}
@@ -72,37 +79,34 @@
// Implementations that call out to JVM. Behold the beauty.
// -----------------------------------------------------------------------------
-bool UniLib::ParseInt32(const UnicodeText& text, int* result) const {
+bool UniLibBase::ParseInt32(const UnicodeText& text, int* result) const {
if (jni_cache_) {
JNIEnv* env = jni_cache_->GetEnv();
- const ScopedLocalRef<jstring> text_java =
- jni_cache_->ConvertToJavaString(text);
- jint res = env->CallStaticIntMethod(jni_cache_->integer_class.get(),
- jni_cache_->integer_parse_int,
- text_java.get());
- if (jni_cache_->ExceptionCheckAndClear()) {
- return false;
- }
- *result = res;
+ TC3_ASSIGN_OR_RETURN_FALSE(const ScopedLocalRef<jstring> text_java,
+ jni_cache_->ConvertToJavaString(text));
+ TC3_ASSIGN_OR_RETURN_FALSE(
+ *result, JniHelper::CallStaticIntMethod(
+ env, jni_cache_->integer_class.get(),
+ jni_cache_->integer_parse_int, text_java.get()));
return true;
}
return false;
}
-std::unique_ptr<UniLib::RegexPattern> UniLib::CreateRegexPattern(
+std::unique_ptr<UniLibBase::RegexPattern> UniLibBase::CreateRegexPattern(
const UnicodeText& regex) const {
- return std::unique_ptr<UniLib::RegexPattern>(
- new UniLib::RegexPattern(jni_cache_.get(), regex, /*lazy=*/false));
+ return std::unique_ptr<UniLibBase::RegexPattern>(
+ new UniLibBase::RegexPattern(jni_cache_.get(), regex, /*lazy=*/false));
}
-std::unique_ptr<UniLib::RegexPattern> UniLib::CreateLazyRegexPattern(
+std::unique_ptr<UniLibBase::RegexPattern> UniLibBase::CreateLazyRegexPattern(
const UnicodeText& regex) const {
- return std::unique_ptr<UniLib::RegexPattern>(
- new UniLib::RegexPattern(jni_cache_.get(), regex, /*lazy=*/true));
+ return std::unique_ptr<UniLibBase::RegexPattern>(
+ new UniLibBase::RegexPattern(jni_cache_.get(), regex, /*lazy=*/true));
}
-UniLib::RegexPattern::RegexPattern(const JniCache* jni_cache,
- const UnicodeText& pattern, bool lazy)
+UniLibBase::RegexPattern::RegexPattern(const JniCache* jni_cache,
+ const UnicodeText& pattern, bool lazy)
: jni_cache_(jni_cache),
pattern_(nullptr, jni_cache ? jni_cache->jvm : nullptr),
initialized_(false),
@@ -113,36 +117,37 @@
}
}
-void UniLib::RegexPattern::LockedInitializeIfNotAlready() const {
+Status UniLibBase::RegexPattern::LockedInitializeIfNotAlready() const {
std::lock_guard<std::mutex> guard(mutex_);
if (initialized_ || initialization_failure_) {
- return;
+ return Status::OK;
}
if (jni_cache_) {
JNIEnv* jenv = jni_cache_->GetEnv();
- const ScopedLocalRef<jstring> regex_java =
- jni_cache_->ConvertToJavaString(pattern_text_);
- pattern_ = MakeGlobalRef(jenv->CallStaticObjectMethod(
- jni_cache_->pattern_class.get(),
- jni_cache_->pattern_compile, regex_java.get()),
- jenv, jni_cache_->jvm);
-
- if (jni_cache_->ExceptionCheckAndClear() || pattern_ == nullptr) {
- initialization_failure_ = true;
- pattern_.reset();
- return;
+ initialization_failure_ = true;
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> regex_java,
+ jni_cache_->ConvertToJavaString(pattern_text_));
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jobject> pattern,
+ JniHelper::CallStaticObjectMethod(
+ jenv, jni_cache_->pattern_class.get(),
+ jni_cache_->pattern_compile, regex_java.get()));
+ pattern_ = MakeGlobalRef(pattern.get(), jenv, jni_cache_->jvm);
+ if (pattern_ == nullptr) {
+ return Status::UNKNOWN;
}
+ initialization_failure_ = false;
initialized_ = true;
pattern_text_.clear(); // We don't need this anymore.
}
+ return Status::OK;
}
-constexpr int UniLib::RegexMatcher::kError;
-constexpr int UniLib::RegexMatcher::kNoError;
+constexpr int UniLibBase::RegexMatcher::kError;
+constexpr int UniLibBase::RegexMatcher::kNoError;
-std::unique_ptr<UniLib::RegexMatcher> UniLib::RegexPattern::Matcher(
+std::unique_ptr<UniLibBase::RegexMatcher> UniLibBase::RegexPattern::Matcher(
const UnicodeText& context) const {
LockedInitializeIfNotAlready(); // Possibly lazy initialization.
if (initialization_failure_) {
@@ -151,35 +156,41 @@
if (jni_cache_) {
JNIEnv* env = jni_cache_->GetEnv();
- const jstring context_java =
- jni_cache_->ConvertToJavaString(context).release();
- if (!context_java) {
+ const StatusOr<ScopedLocalRef<jstring>> status_or_context_java =
+ jni_cache_->ConvertToJavaString(context);
+ if (!status_or_context_java.ok() || !status_or_context_java.ValueOrDie()) {
return nullptr;
}
- const jobject matcher = env->CallObjectMethod(
- pattern_.get(), jni_cache_->pattern_matcher, context_java);
- if (jni_cache_->ExceptionCheckAndClear() || !matcher) {
+ const StatusOr<ScopedLocalRef<jobject>> status_or_matcher =
+ JniHelper::CallObjectMethod(env, pattern_.get(),
+ jni_cache_->pattern_matcher,
+ status_or_context_java.ValueOrDie().get());
+ if (jni_cache_->ExceptionCheckAndClear() || !status_or_matcher.ok() ||
+ !status_or_matcher.ValueOrDie()) {
return nullptr;
}
- return std::unique_ptr<UniLib::RegexMatcher>(new RegexMatcher(
- jni_cache_, MakeGlobalRef(matcher, env, jni_cache_->jvm),
- MakeGlobalRef(context_java, env, jni_cache_->jvm)));
+ return std::unique_ptr<UniLibBase::RegexMatcher>(new RegexMatcher(
+ jni_cache_,
+ MakeGlobalRef(status_or_matcher.ValueOrDie().get(), env,
+ jni_cache_->jvm),
+ MakeGlobalRef(status_or_context_java.ValueOrDie().get(), env,
+ jni_cache_->jvm)));
} else {
// NOTE: A valid object needs to be created here to pass the interface
// tests.
- return std::unique_ptr<UniLib::RegexMatcher>(
- new RegexMatcher(jni_cache_, nullptr, nullptr));
+ return std::unique_ptr<UniLibBase::RegexMatcher>(
+ new RegexMatcher(jni_cache_, {}, {}));
}
}
-UniLib::RegexMatcher::RegexMatcher(const JniCache* jni_cache,
- ScopedGlobalRef<jobject> matcher,
- ScopedGlobalRef<jstring> text)
+UniLibBase::RegexMatcher::RegexMatcher(const JniCache* jni_cache,
+ ScopedGlobalRef<jobject> matcher,
+ ScopedGlobalRef<jstring> text)
: jni_cache_(jni_cache),
matcher_(std::move(matcher)),
text_(std::move(text)) {}
-bool UniLib::RegexMatcher::Matches(int* status) const {
+bool UniLibBase::RegexMatcher::Matches(int* status) const {
if (jni_cache_) {
*status = kNoError;
const bool result = jni_cache_->GetEnv()->CallBooleanMethod(
@@ -195,7 +206,7 @@
}
}
-bool UniLib::RegexMatcher::ApproximatelyMatches(int* status) {
+bool UniLibBase::RegexMatcher::ApproximatelyMatches(int* status) {
*status = kNoError;
jni_cache_->GetEnv()->CallObjectMethod(matcher_.get(),
@@ -237,7 +248,7 @@
return true;
}
-bool UniLib::RegexMatcher::UpdateLastFindOffset() const {
+bool UniLibBase::RegexMatcher::UpdateLastFindOffset() const {
if (!last_find_offset_dirty_) {
return true;
}
@@ -262,7 +273,7 @@
return true;
}
-bool UniLib::RegexMatcher::Find(int* status) {
+bool UniLibBase::RegexMatcher::Find(int* status) {
if (jni_cache_) {
const bool result = jni_cache_->GetEnv()->CallBooleanMethod(
matcher_.get(), jni_cache_->matcher_find);
@@ -280,11 +291,11 @@
}
}
-int UniLib::RegexMatcher::Start(int* status) const {
+int UniLibBase::RegexMatcher::Start(int* status) const {
return Start(/*group_idx=*/0, status);
}
-int UniLib::RegexMatcher::Start(int group_idx, int* status) const {
+int UniLibBase::RegexMatcher::Start(int group_idx, int* status) const {
if (jni_cache_) {
*status = kNoError;
@@ -320,11 +331,11 @@
}
}
-int UniLib::RegexMatcher::End(int* status) const {
+int UniLibBase::RegexMatcher::End(int* status) const {
return End(/*group_idx=*/0, status);
}
-int UniLib::RegexMatcher::End(int group_idx, int* status) const {
+int UniLibBase::RegexMatcher::End(int group_idx, int* status) const {
if (jni_cache_) {
*status = kNoError;
@@ -360,20 +371,26 @@
}
}
-UnicodeText UniLib::RegexMatcher::Group(int* status) const {
+UnicodeText UniLibBase::RegexMatcher::Group(int* status) const {
if (jni_cache_) {
JNIEnv* jenv = jni_cache_->GetEnv();
- const ScopedLocalRef<jstring> java_result(
- reinterpret_cast<jstring>(
- jenv->CallObjectMethod(matcher_.get(), jni_cache_->matcher_group)),
- jenv);
- if (jni_cache_->ExceptionCheckAndClear() || !java_result) {
+ StatusOr<ScopedLocalRef<jstring>> status_or_java_result =
+ JniHelper::CallObjectMethod<jstring>(jenv, matcher_.get(),
+ jni_cache_->matcher_group);
+
+ if (jni_cache_->ExceptionCheckAndClear() || !status_or_java_result.ok() ||
+ !status_or_java_result.ValueOrDie()) {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
std::string result;
- if (!JStringToUtf8String(jenv, java_result.get(), &result)) {
+ if (!JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get(),
+ &result)) {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+ if (result.empty()) {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
@@ -385,14 +402,14 @@
}
}
-UnicodeText UniLib::RegexMatcher::Group(int group_idx, int* status) const {
+UnicodeText UniLibBase::RegexMatcher::Group(int group_idx, int* status) const {
if (jni_cache_) {
JNIEnv* jenv = jni_cache_->GetEnv();
- const ScopedLocalRef<jstring> java_result(
- reinterpret_cast<jstring>(jenv->CallObjectMethod(
- matcher_.get(), jni_cache_->matcher_group_idx, group_idx)),
- jenv);
- if (jni_cache_->ExceptionCheckAndClear()) {
+
+ StatusOr<ScopedLocalRef<jstring>> status_or_java_result =
+ JniHelper::CallObjectMethod<jstring>(
+ jenv, matcher_.get(), jni_cache_->matcher_group_idx, group_idx);
+ if (jni_cache_->ExceptionCheckAndClear() || !status_or_java_result.ok()) {
*status = kError;
TC3_LOG(ERROR) << "Exception occurred";
return UTF8ToUnicodeText("", /*do_copy=*/false);
@@ -401,13 +418,18 @@
// java_result is nullptr when the group did not participate in the match.
// For these cases other UniLib implementations return empty string, and
// the participation can be checked by checking if Start() == -1.
- if (!java_result) {
+ if (!status_or_java_result.ValueOrDie()) {
*status = kNoError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
std::string result;
- if (!JStringToUtf8String(jenv, java_result.get(), &result)) {
+ if (!JStringToUtf8String(jenv, status_or_java_result.ValueOrDie().get(),
+ &result)) {
+ *status = kError;
+ return UTF8ToUnicodeText("", /*do_copy=*/false);
+ }
+ if (result.empty()) {
*status = kError;
return UTF8ToUnicodeText("", /*do_copy=*/false);
}
@@ -419,10 +441,10 @@
}
}
-constexpr int UniLib::BreakIterator::kDone;
+constexpr int UniLibBase::BreakIterator::kDone;
-UniLib::BreakIterator::BreakIterator(const JniCache* jni_cache,
- const UnicodeText& text)
+UniLibBase::BreakIterator::BreakIterator(const JniCache* jni_cache,
+ const UnicodeText& text)
: jni_cache_(jni_cache),
text_(nullptr, jni_cache ? jni_cache->jvm : nullptr),
iterator_(nullptr, jni_cache ? jni_cache->jvm : nullptr),
@@ -430,26 +452,36 @@
last_unicode_index_(0) {
if (jni_cache_) {
JNIEnv* jenv = jni_cache_->GetEnv();
- text_ = MakeGlobalRef(jni_cache_->ConvertToJavaString(text).release(), jenv,
- jni_cache->jvm);
+ StatusOr<ScopedLocalRef<jstring>> status_or_text =
+ jni_cache_->ConvertToJavaString(text);
+ if (!status_or_text.ok()) {
+ return;
+ }
+ text_ =
+ MakeGlobalRef(status_or_text.ValueOrDie().get(), jenv, jni_cache->jvm);
if (!text_) {
return;
}
- iterator_ = MakeGlobalRef(
- jenv->CallStaticObjectMethod(jni_cache->breakiterator_class.get(),
- jni_cache->breakiterator_getwordinstance,
- jni_cache->locale_us.get()),
- jenv, jni_cache->jvm);
+ StatusOr<ScopedLocalRef<jobject>> status_or_iterator =
+ JniHelper::CallStaticObjectMethod(
+ jenv, jni_cache->breakiterator_class.get(),
+ jni_cache->breakiterator_getwordinstance,
+ jni_cache->locale_us.get());
+ if (!status_or_iterator.ok()) {
+ return;
+ }
+ iterator_ = MakeGlobalRef(status_or_iterator.ValueOrDie().get(), jenv,
+ jni_cache->jvm);
if (!iterator_) {
return;
}
- jenv->CallVoidMethod(iterator_.get(), jni_cache->breakiterator_settext,
- text_.get());
+ JniHelper::CallVoidMethod(jenv, iterator_.get(),
+ jni_cache->breakiterator_settext, text_.get());
}
}
-int UniLib::BreakIterator::Next() {
+int UniLibBase::BreakIterator::Next() {
if (jni_cache_) {
const int break_index = jni_cache_->GetEnv()->CallIntMethod(
iterator_.get(), jni_cache_->breakiterator_next);
@@ -471,10 +503,10 @@
return BreakIterator::kDone;
}
-std::unique_ptr<UniLib::BreakIterator> UniLib::CreateBreakIterator(
+std::unique_ptr<UniLibBase::BreakIterator> UniLibBase::CreateBreakIterator(
const UnicodeText& text) const {
- return std::unique_ptr<UniLib::BreakIterator>(
- new UniLib::BreakIterator(jni_cache_.get(), text));
+ return std::unique_ptr<UniLibBase::BreakIterator>(
+ new UniLibBase::BreakIterator(jni_cache_.get(), text));
}
} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib-javaicu.h b/native/utils/utf8/unilib-javaicu.h
index 77f4970..549aed5 100644
--- a/native/utils/utf8/unilib-javaicu.h
+++ b/native/utils/utf8/unilib-javaicu.h
@@ -22,23 +22,23 @@
#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_JAVAICU_H_
#include <jni.h>
+
#include <memory>
#include <mutex> // NOLINT
#include <string>
#include "utils/base/integral_types.h"
+#include "utils/java/jni-base.h"
#include "utils/java/jni-cache.h"
-#include "utils/java/scoped_global_ref.h"
-#include "utils/java/scoped_local_ref.h"
#include "utils/java/string_utils.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
-class UniLib {
+class UniLibBase {
public:
- UniLib();
- explicit UniLib(const std::shared_ptr<JniCache>& jni_cache);
+ UniLibBase();
+ explicit UniLibBase(const std::shared_ptr<JniCache>& jni_cache);
bool ParseInt32(const UnicodeText& text, int* result) const;
bool IsOpeningBracket(char32 codepoint) const;
@@ -47,6 +47,7 @@
bool IsDigit(char32 codepoint) const;
bool IsLower(char32 codepoint) const;
bool IsUpper(char32 codepoint) const;
+ bool IsPunctuation(char32 codepoint) const;
char32 ToLower(char32 codepoint) const;
char32 ToUpper(char32 codepoint) const;
@@ -134,10 +135,10 @@
std::unique_ptr<RegexMatcher> Matcher(const UnicodeText& context) const;
private:
- friend class UniLib;
+ friend class UniLibBase;
RegexPattern(const JniCache* jni_cache, const UnicodeText& pattern,
bool lazy);
- void LockedInitializeIfNotAlready() const;
+ Status LockedInitializeIfNotAlready() const;
const JniCache* jni_cache_;
@@ -159,7 +160,7 @@
static constexpr int kDone = -1;
private:
- friend class UniLib;
+ friend class UniLibBase;
BreakIterator(const JniCache* jni_cache, const UnicodeText& text);
const JniCache* jni_cache_;
diff --git a/native/utils/utf8/unilib.h b/native/utils/utf8/unilib.h
index ec1f329..4f09b5c 100644
--- a/native/utils/utf8/unilib.h
+++ b/native/utils/utf8/unilib.h
@@ -17,7 +17,45 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_
#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_
+#include "utils/base/integral_types.h"
+#include "utils/utf8/unicodetext.h"
+
#include "utils/utf8/unilib-javaicu.h"
#define INIT_UNILIB_FOR_TESTING(VAR) VAR(nullptr)
+namespace libtextclassifier3 {
+
+class UniLib : public UniLibBase {
+ public:
+ using UniLibBase::UniLibBase;
+
+ // Lowercase a unicode string.
+ UnicodeText ToLowerText(const UnicodeText& text) const {
+ UnicodeText result;
+ for (const char32 codepoint : text) {
+ result.push_back(ToLower(codepoint));
+ }
+ return result;
+ }
+
+ // Uppercase a unicode string.
+ UnicodeText ToUpperText(const UnicodeText& text) const {
+ UnicodeText result;
+ for (const char32 codepoint : text) {
+ result.push_back(UniLibBase::ToUpper(codepoint));
+ }
+ return result;
+ }
+
+ bool IsDigits(const UnicodeText& text) const {
+ for (const char32 codepoint : text) {
+ if (!IsDigit(codepoint)) {
+ return false;
+ }
+ }
+ return true;
+ }
+};
+
+} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_H_
diff --git a/native/utils/utf8/unilib_test-include.cc b/native/utils/utf8/unilib_test-include.cc
index 9465ea6..f931ccc 100644
--- a/native/utils/utf8/unilib_test-include.cc
+++ b/native/utils/utf8/unilib_test-include.cc
@@ -16,7 +16,7 @@
#include "utils/utf8/unilib_test-include.h"
-#include "utils/utf8/unicodetext.h"
+#include "utils/base/logging.h"
#include "gmock/gmock.h"
namespace libtextclassifier3 {
@@ -38,12 +38,24 @@
EXPECT_FALSE(unilib_.IsLower(')'));
EXPECT_TRUE(unilib_.IsLower('a'));
EXPECT_TRUE(unilib_.IsLower('z'));
+ EXPECT_TRUE(unilib_.IsPunctuation('!'));
+ EXPECT_TRUE(unilib_.IsPunctuation('?'));
+ EXPECT_TRUE(unilib_.IsPunctuation('#'));
+ EXPECT_TRUE(unilib_.IsPunctuation('('));
+ EXPECT_FALSE(unilib_.IsPunctuation('0'));
+ EXPECT_FALSE(unilib_.IsPunctuation('$'));
EXPECT_EQ(unilib_.ToLower('A'), 'a');
EXPECT_EQ(unilib_.ToLower('Z'), 'z');
EXPECT_EQ(unilib_.ToLower(')'), ')');
+ EXPECT_EQ(unilib_.ToLowerText(UTF8ToUnicodeText("Never gonna give you up."))
+ .ToUTF8String(),
+ "never gonna give you up.");
EXPECT_EQ(unilib_.ToUpper('a'), 'A');
EXPECT_EQ(unilib_.ToUpper('z'), 'Z');
EXPECT_EQ(unilib_.ToUpper(')'), ')');
+ EXPECT_EQ(unilib_.ToUpperText(UTF8ToUnicodeText("Never gonna let you down."))
+ .ToUTF8String(),
+ "NEVER GONNA LET YOU DOWN.");
EXPECT_EQ(unilib_.GetPairedBracket(')'), '(');
EXPECT_EQ(unilib_.GetPairedBracket('}'), '{');
}
@@ -60,24 +72,38 @@
EXPECT_FALSE(unilib_.IsUpper(0x0211)); // SMALL R WITH DOUBLE GRAVE
EXPECT_TRUE(unilib_.IsUpper(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
EXPECT_TRUE(unilib_.IsUpper(0x0391)); // GREEK CAPITAL ALPHA
- EXPECT_TRUE(unilib_.IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL
- EXPECT_FALSE(unilib_.IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
- EXPECT_TRUE(unilib_.IsLower(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
- EXPECT_TRUE(unilib_.IsLower(0x03B1)); // GREEK SMALL ALPHA
- EXPECT_TRUE(unilib_.IsLower(0x03CB)); // GREEK SMALL UPSILON
- EXPECT_TRUE(unilib_.IsLower(0x0211)); // SMALL R WITH DOUBLE GRAVE
- EXPECT_TRUE(unilib_.IsLower(0x03C0)); // GREEK SMALL PI
- EXPECT_TRUE(unilib_.IsLower(0x007a)); // SMALL Z
- EXPECT_FALSE(unilib_.IsLower(0x005a)); // CAPITAL Z
- EXPECT_FALSE(unilib_.IsLower(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
- EXPECT_FALSE(unilib_.IsLower(0x0391)); // GREEK CAPITAL ALPHA
- EXPECT_EQ(unilib_.ToLower(0x0391), 0x03B1); // GREEK ALPHA
- EXPECT_EQ(unilib_.ToLower(0x03AB), 0x03CB); // GREEK UPSILON WITH DIALYTIKA
- EXPECT_EQ(unilib_.ToLower(0x03C0), 0x03C0); // GREEK SMALL PI
+ EXPECT_TRUE(unilib_.IsUpper(0x03AB)); // GREEK CAPITAL UPSILON W DIAL
+ EXPECT_FALSE(unilib_.IsUpper(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
+ EXPECT_TRUE(unilib_.IsLower(0x03AC)); // GREEK SMALL ALPHA WITH TONOS
+ EXPECT_TRUE(unilib_.IsLower(0x03B1)); // GREEK SMALL ALPHA
+ EXPECT_TRUE(unilib_.IsLower(0x03CB)); // GREEK SMALL UPSILON
+ EXPECT_TRUE(unilib_.IsLower(0x0211)); // SMALL R WITH DOUBLE GRAVE
+ EXPECT_TRUE(unilib_.IsLower(0x03C0)); // GREEK SMALL PI
+ EXPECT_TRUE(unilib_.IsLower(0x007A)); // SMALL Z
+ EXPECT_FALSE(unilib_.IsLower(0x005A)); // CAPITAL Z
+ EXPECT_FALSE(unilib_.IsLower(0x0212)); // CAPITAL R WITH DOUBLE GRAVE
+ EXPECT_FALSE(unilib_.IsLower(0x0391)); // GREEK CAPITAL ALPHA
+ EXPECT_TRUE(unilib_.IsPunctuation(0x055E)); // ARMENIAN QUESTION MARK
+ EXPECT_TRUE(unilib_.IsPunctuation(0x066C)); // ARABIC THOUSANDS SEPARATOR
+ EXPECT_TRUE(unilib_.IsPunctuation(0x07F7)); // NKO SYMBOL GBAKURUNEN
+ EXPECT_TRUE(unilib_.IsPunctuation(0x10AF2)); // DOUBLE DOT WITHIN DOT
+ EXPECT_FALSE(unilib_.IsPunctuation(0x00A3)); // POUND SIGN
+ EXPECT_FALSE(unilib_.IsPunctuation(0xA838)); // NORTH INDIC RUPEE MARK
+ EXPECT_EQ(unilib_.ToLower(0x0391), 0x03B1); // GREEK ALPHA
+ EXPECT_EQ(unilib_.ToLower(0x03AB), 0x03CB); // GREEK UPSILON WITH DIALYTIKA
+ EXPECT_EQ(unilib_.ToLower(0x03C0), 0x03C0); // GREEK SMALL PI
+ EXPECT_EQ(unilib_.ToLower(0x03A3), 0x03C3); // GREEK CAPITAL LETTER SIGMA
+ EXPECT_EQ(unilib_.ToLowerText(UTF8ToUnicodeText("Κανένας άνθρωπος δεν ξέρει"))
+ .ToUTF8String(),
+ "κανένας άνθρωπος δεν ξέρει");
EXPECT_EQ(unilib_.ToUpper(0x03B1), 0x0391); // GREEK ALPHA
EXPECT_EQ(unilib_.ToUpper(0x03CB), 0x03AB); // GREEK UPSILON WITH DIALYTIKA
EXPECT_EQ(unilib_.ToUpper(0x0391), 0x0391); // GREEK CAPITAL ALPHA
-
+ EXPECT_EQ(unilib_.ToUpper(0x03C3), 0x03A3); // GREEK CAPITAL LETTER SIGMA
+ EXPECT_EQ(unilib_.ToUpper(0x03C2), 0x03A3); // GREEK CAPITAL LETTER SIGMA
+ EXPECT_EQ(unilib_.ToUpperText(UTF8ToUnicodeText("Κανένας άνθρωπος δεν ξέρει"))
+ .ToUTF8String(),
+ "ΚΑΝΈΝΑΣ ΆΝΘΡΩΠΟΣ ΔΕΝ ΞΈΡΕΙ");
EXPECT_EQ(unilib_.GetPairedBracket(0x0F3C), 0x0F3D);
EXPECT_EQ(unilib_.GetPairedBracket(0x0F3D), 0x0F3C);
}
@@ -235,6 +261,5 @@
EXPECT_FALSE(unilib_.ParseInt32(UTF8ToUnicodeText("1a3", /*do_copy=*/false),
&result));
}
-
} // namespace test_internal
} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib_test-include.h b/native/utils/utf8/unilib_test-include.h
index b4efcd6..342a00c 100644
--- a/native/utils/utf8/unilib_test-include.h
+++ b/native/utils/utf8/unilib_test-include.h
@@ -17,26 +17,21 @@
#ifndef LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
#define LIBTEXTCLASSIFIER_UTILS_UTF8_UNILIB_TEST_INCLUDE_H_
-// Include the version of UniLib depending on the macro.
+#include "utils/utf8/unilib.h"
+#include "gtest/gtest.h"
+
#if defined TC3_UNILIB_ICU
-#include "utils/utf8/unilib-icu.h"
#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
#elif defined TC3_UNILIB_JAVAICU
#include <jni.h>
extern JNIEnv* g_jenv;
#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR(JniCache::Create(g_jenv))
-#include "utils/utf8/unilib-javaicu.h"
#elif defined TC3_UNILIB_APPLE
-#include "utils/utf8/unilib-apple.h"
#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
#elif defined TC3_UNILIB_DUMMY
-#include "utils/utf8/unilib-dummy.h"
#define TC3_TESTING_CREATE_UNILIB_INSTANCE(VAR) VAR()
#endif
-#include "utils/base/logging.h"
-#include "gtest/gtest.h"
-
namespace libtextclassifier3 {
namespace test_internal {
diff --git a/native/utils/utf8/unilib_test.cc b/native/utils/utf8/unilib_test.cc
index b5658af..01b5164 100644
--- a/native/utils/utf8/unilib_test.cc
+++ b/native/utils/utf8/unilib_test.cc
@@ -14,7 +14,5 @@
* limitations under the License.
*/
-#include "gtest/gtest.h"
-
// The actual code of the test is in the following include:
#include "utils/utf8/unilib_test-include.h"