Import libtextclassifier
Test: atest TextClassifierServiceTest
Change-Id: Ief715193072d0af3aea230c3c9475ef18e8ac84c
diff --git a/java/src/com/android/textclassifier/ActionsModelParamsSupplier.java b/java/src/com/android/textclassifier/ActionsModelParamsSupplier.java
index 8d122ad..a1bf109 100644
--- a/java/src/com/android/textclassifier/ActionsModelParamsSupplier.java
+++ b/java/src/com/android/textclassifier/ActionsModelParamsSupplier.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -11,202 +11,191 @@
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
- * limitations under the License
+ * limitations under the License.
*/
+
package com.android.textclassifier;
-import android.content.ContentResolver;
import android.content.Context;
import android.database.ContentObserver;
import android.provider.Settings;
-
import androidx.annotation.GuardedBy;
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-import androidx.core.util.Preconditions;
-
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
import java.lang.ref.WeakReference;
import java.util.Objects;
import java.util.function.Supplier;
+import javax.annotation.Nullable;
/** Parses the {@link Settings.Global#TEXT_CLASSIFIER_ACTION_MODEL_PARAMS} flag. */
+// TODO(tonymak): Re-enable this class.
public final class ActionsModelParamsSupplier
- implements Supplier<ActionsModelParamsSupplier.ActionsModelParams> {
- private static final String TAG = "ActionsModelParamsSupplier";
+ implements Supplier<ActionsModelParamsSupplier.ActionsModelParams> {
+ private static final String TAG = "ActionsModelParams";
- @VisibleForTesting static final String KEY_REQUIRED_MODEL_VERSION = "required_model_version";
- @VisibleForTesting static final String KEY_REQUIRED_LOCALES = "required_locales";
+ @VisibleForTesting static final String KEY_REQUIRED_MODEL_VERSION = "required_model_version";
+ @VisibleForTesting static final String KEY_REQUIRED_LOCALES = "required_locales";
- @VisibleForTesting
- static final String KEY_SERIALIZED_PRECONDITIONS = "serialized_preconditions";
+ @VisibleForTesting static final String KEY_SERIALIZED_PRECONDITIONS = "serialized_preconditions";
- private final Context mAppContext;
- private final SettingsObserver mSettingsObserver;
+ private final Context appContext;
+ private final SettingsObserver settingsObserver;
- private final Object mLock = new Object();
- private final Runnable mOnChangedListener;
+ private final Object lock = new Object();
+ private final Runnable onChangedListener;
- @Nullable
- @GuardedBy("mLock")
- private ActionsModelParams mActionsModelParams;
+ @Nullable
+ @GuardedBy("lock")
+ private ActionsModelParams actionsModelParams;
- @GuardedBy("mLock")
- private boolean mParsed = true;
+ @GuardedBy("lock")
+ private boolean parsed = true;
- public ActionsModelParamsSupplier(Context context, @Nullable Runnable onChangedListener) {
- mAppContext = Preconditions.checkNotNull(context).getApplicationContext();
- mOnChangedListener = onChangedListener == null ? () -> {} : onChangedListener;
- mSettingsObserver =
- new SettingsObserver(
- mAppContext,
- () -> {
- synchronized (mLock) {
- TcLog.v(
- TAG,
- "Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS is"
- + " updated");
- mParsed = true;
- mOnChangedListener.run();
- }
- });
+ public ActionsModelParamsSupplier(Context context, @Nullable Runnable onChangedListener) {
+ appContext = Preconditions.checkNotNull(context).getApplicationContext();
+ this.onChangedListener = onChangedListener == null ? () -> {} : onChangedListener;
+ settingsObserver =
+ new SettingsObserver(
+ appContext,
+ () -> {
+ synchronized (lock) {
+ TcLog.v(TAG, "Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS is updated");
+ parsed = true;
+ this.onChangedListener.run();
+ }
+ });
+ }
+
+ /**
+ * Returns the parsed actions params or {@link ActionsModelParams#INVALID} if the value is
+ * invalid.
+ */
+ @Override
+ public ActionsModelParams get() {
+ synchronized (lock) {
+ if (parsed) {
+ actionsModelParams = parse();
+ parsed = false;
+ }
+ return actionsModelParams;
+ }
+ }
+
+ private static ActionsModelParams parse() {
+ // String settingStr = Settings.Global.getString(contentResolver,
+ // Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS);
+ // if (TextUtils.isEmpty(settingStr)) {
+ // return ActionsModelParams.INVALID;
+ // }
+ // try {
+ // KeyValueListParser keyValueListParser = new KeyValueListParser(',');
+ // keyValueListParser.setString(settingStr);
+ // int version = keyValueListParser.getInt(KEY_REQUIRED_MODEL_VERSION, -1);
+ // if (version == -1) {
+ // TcLog.w(TAG, "ActionsModelParams.Parse, invalid model version");
+ // return ActionsModelParams.INVALID;
+ // }
+ // String locales = keyValueListParser.getString(KEY_REQUIRED_LOCALES, null);
+ // if (locales == null) {
+ // TcLog.w(TAG, "ActionsModelParams.Parse, invalid locales");
+ // return ActionsModelParams.INVALID;
+ // }
+ // String serializedPreconditionsStr =
+ // keyValueListParser.getString(KEY_SERIALIZED_PRECONDITIONS, null);
+ // if (serializedPreconditionsStr == null) {
+ // TcLog.w(TAG, "ActionsModelParams.Parse, invalid preconditions");
+ // return ActionsModelParams.INVALID;
+ // }
+ // byte[] serializedPreconditions =
+ // Base64.decode(serializedPreconditionsStr, Base64.NO_WRAP);
+ // return new ActionsModelParams(version, locales, serializedPreconditions);
+ // } catch (Throwable t) {
+ // TcLog.e(TAG, "Invalid TEXT_CLASSIFIER_ACTION_MODEL_PARAMS, ignore", t);
+ // }
+ return ActionsModelParams.INVALID;
+ }
+
+ @Override
+ protected void finalize() throws Throwable {
+ try {
+ appContext.getContentResolver().unregisterContentObserver(settingsObserver);
+ } finally {
+ super.finalize();
+ }
+ }
+
+ /** Represents the parsed result. */
+ public static final class ActionsModelParams {
+
+ public static final ActionsModelParams INVALID = new ActionsModelParams(-1, "", new byte[0]);
+
+ /** The required model version to apply {@code serializedPreconditions}. */
+ private final int requiredModelVersion;
+
+ /** The required model locales to apply {@code serializedPreconditions}. */
+ private final String requiredModelLocales;
+
+ /**
+ * The serialized params that will be applied to the model file, if all requirements are met. Do
+ * not modify.
+ */
+ private final byte[] serializedPreconditions;
+
+ public ActionsModelParams(
+ int requiredModelVersion, String requiredModelLocales, byte[] serializedPreconditions) {
+ this.requiredModelVersion = requiredModelVersion;
+ this.requiredModelLocales = Preconditions.checkNotNull(requiredModelLocales);
+ this.serializedPreconditions = Preconditions.checkNotNull(serializedPreconditions);
}
/**
- * Returns the parsed actions params or {@link ActionsModelParams#INVALID} if the value is
- * invalid.
+ * Returns the serialized preconditions. Returns {@code null} if the model in use does not meet
+ * all the requirements listed in the {@code ActionsModelParams} or the params are invalid.
*/
- @Override
- public ActionsModelParams get() {
- synchronized (mLock) {
- if (mParsed) {
- mActionsModelParams = parse(mAppContext.getContentResolver());
- mParsed = false;
- }
- return mActionsModelParams;
- }
+    @Nullable
+    public byte[] getSerializedPreconditions(ModelFileManager.ModelFile modelInUse) {
+      if (this == INVALID) {
+        return null;
+      }
+      if (modelInUse.getVersion() != requiredModelVersion) {
+        TcLog.w(
+            TAG,
+            String.format(
+                "Not applying serializedPreconditions, required version=%d, actual=%d",
+                requiredModelVersion, modelInUse.getVersion()));
+        return null;
+      }
+      if (!Objects.equals(modelInUse.getSupportedLocalesStr(), requiredModelLocales)) {
+        TcLog.w(
+            TAG,
+            String.format(
+                "Not applying serializedPreconditions, required locales=%s, actual=%s",
+                requiredModelLocales, modelInUse.getSupportedLocalesStr()));
+        return null;
+      }
+      return serializedPreconditions;
+    }
+ }
- private ActionsModelParams parse(ContentResolver contentResolver) {
- // String settingStr = Settings.Global.getString(contentResolver,
- // Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS);
- // if (TextUtils.isEmpty(settingStr)) {
- // return ActionsModelParams.INVALID;
- // }
- // try {
- // KeyValueListParser keyValueListParser = new KeyValueListParser(',');
- // keyValueListParser.setString(settingStr);
- // int version = keyValueListParser.getInt(KEY_REQUIRED_MODEL_VERSION, -1);
- // if (version == -1) {
- // TcLog.w(TAG, "ActionsModelParams.Parse, invalid model version");
- // return ActionsModelParams.INVALID;
- // }
- // String locales = keyValueListParser.getString(KEY_REQUIRED_LOCALES, null);
- // if (locales == null) {
- // TcLog.w(TAG, "ActionsModelParams.Parse, invalid locales");
- // return ActionsModelParams.INVALID;
- // }
- // String serializedPreconditionsStr =
- // keyValueListParser.getString(KEY_SERIALIZED_PRECONDITIONS, null);
- // if (serializedPreconditionsStr == null) {
- // TcLog.w(TAG, "ActionsModelParams.Parse, invalid preconditions");
- // return ActionsModelParams.INVALID;
- // }
- // byte[] serializedPreconditions =
- // Base64.decode(serializedPreconditionsStr, Base64.NO_WRAP);
- // return new ActionsModelParams(version, locales, serializedPreconditions);
- // } catch (Throwable t) {
- // TcLog.e(TAG, "Invalid TEXT_CLASSIFIER_ACTION_MODEL_PARAMS, ignore", t);
- // }
- return ActionsModelParams.INVALID;
+ private static final class SettingsObserver extends ContentObserver {
+
+ private final WeakReference<Runnable> onChangedListener;
+
+ SettingsObserver(Context appContext, Runnable listener) {
+ super(null);
+ onChangedListener = new WeakReference<>(listener);
+ // appContext.getContentResolver().registerContentObserver(
+ //
+ // Settings.Global.getUriFor(Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS),
+ // false /* notifyForDescendants */,
+ // this);
}
@Override
- protected void finalize() throws Throwable {
- try {
- mAppContext.getContentResolver().unregisterContentObserver(mSettingsObserver);
- } finally {
- super.finalize();
- }
+ public void onChange(boolean selfChange) {
+ if (onChangedListener.get() != null) {
+ onChangedListener.get().run();
+ }
}
-
- /** Represents the parsed result. */
- public static final class ActionsModelParams {
-
- public static final ActionsModelParams INVALID =
- new ActionsModelParams(-1, "", new byte[0]);
-
- /** The required model version to apply {@code mSerializedPreconditions}. */
- private final int mRequiredModelVersion;
-
- /** The required model locales to apply {@code mSerializedPreconditions}. */
- private final String mRequiredModelLocales;
-
- /**
- * The serialized params that will be applied to the model file, if all requirements are
- * met. Do not modify.
- */
- private final byte[] mSerializedPreconditions;
-
- public ActionsModelParams(
- int requiredModelVersion,
- String requiredModelLocales,
- byte[] serializedPreconditions) {
- mRequiredModelVersion = requiredModelVersion;
- mRequiredModelLocales = Preconditions.checkNotNull(requiredModelLocales);
- mSerializedPreconditions = Preconditions.checkNotNull(serializedPreconditions);
- }
-
- /**
- * Returns the serialized preconditions. Returns {@code null} if the the model in use does
- * not meet all the requirements listed in the {@code ActionsModelParams} or the params are
- * invalid.
- */
- @Nullable
- public byte[] getSerializedPreconditions(ModelFileManager.ModelFile modelInUse) {
- if (this == INVALID) {
- return null;
- }
- if (modelInUse.getVersion() != mRequiredModelVersion) {
- TcLog.w(
- TAG,
- String.format(
- "Not applying mSerializedPreconditions, required version=%d,"
- + " actual=%d",
- mRequiredModelVersion, modelInUse.getVersion()));
- return null;
- }
- if (!Objects.equals(modelInUse.getSupportedLocalesStr(), mRequiredModelLocales)) {
- TcLog.w(
- TAG,
- String.format(
- "Not applying mSerializedPreconditions, required locales=%s,"
- + " actual=%s",
- mRequiredModelLocales, modelInUse.getSupportedLocalesStr()));
- return null;
- }
- return mSerializedPreconditions;
- }
- }
-
- private static final class SettingsObserver extends ContentObserver {
-
- private final WeakReference<Runnable> mOnChangedListener;
-
- SettingsObserver(Context appContext, Runnable listener) {
- super(null);
- mOnChangedListener = new WeakReference<>(listener);
- // appContext.getContentResolver().registerContentObserver(
- //
- // Settings.Global.getUriFor(Settings.Global.TEXT_CLASSIFIER_ACTION_MODEL_PARAMS),
- // false /* notifyForDescendants */,
- // this);
- }
-
- @Override
- public void onChange(boolean selfChange) {
- if (mOnChangedListener.get() != null) {
- mOnChangedListener.get().run();
- }
- }
- }
+ }
}
diff --git a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
index 41f0449..55ee402 100644
--- a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
+++ b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
- * limitations under the License
+ * limitations under the License.
*/
package com.android.textclassifier;
@@ -26,17 +26,11 @@
import android.util.Pair;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
-
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-
import com.android.textclassifier.intent.LabeledIntent;
import com.android.textclassifier.intent.TemplateIntentFactory;
import com.android.textclassifier.logging.ResultIdUtils;
-
import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.RemoteActionTemplate;
-
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
@@ -46,182 +40,178 @@
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
+import javax.annotation.Nullable;
/** Helper class for action suggestions. */
-@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
-public final class ActionsSuggestionsHelper {
- private static final String TAG = "ActionsSuggestions";
- private static final int USER_LOCAL = 0;
- private static final int FIRST_NON_LOCAL_USER = 1;
+final class ActionsSuggestionsHelper {
+ private static final String TAG = "ActionsSuggestions";
+ private static final int USER_LOCAL = 0;
+ private static final int FIRST_NON_LOCAL_USER = 1;
- private ActionsSuggestionsHelper() {}
+ private ActionsSuggestionsHelper() {}
- /**
- * Converts the messages to a list of native messages object that the model can understand.
- *
- * <p>User id encoding - local user is represented as 0, Other users are numbered according to
- * how far before they spoke last time in the conversation. For example, considering this
- * conversation:
- *
- * <ul>
- * <li>User A: xxx
- * <li>Local user: yyy
- * <li>User B: zzz
- * </ul>
- *
- * User A will be encoded as 2, user B will be encoded as 1 and local user will be encoded as 0.
- */
- public static ActionsSuggestionsModel.ConversationMessage[] toNativeMessages(
- List<ConversationActions.Message> messages,
- Function<CharSequence, List<String>> languageDetector) {
- List<ConversationActions.Message> messagesWithText =
- messages.stream()
- .filter(message -> !TextUtils.isEmpty(message.getText()))
- .collect(Collectors.toCollection(ArrayList::new));
- if (messagesWithText.isEmpty()) {
- return new ActionsSuggestionsModel.ConversationMessage[0];
- }
- Deque<ActionsSuggestionsModel.ConversationMessage> nativeMessages = new ArrayDeque<>();
- PersonEncoder personEncoder = new PersonEncoder();
- int size = messagesWithText.size();
- for (int i = size - 1; i >= 0; i--) {
- ConversationActions.Message message = messagesWithText.get(i);
- long referenceTime =
- message.getReferenceTime() == null
- ? 0
- : message.getReferenceTime().toInstant().toEpochMilli();
- String timeZone =
- message.getReferenceTime() == null
- ? null
- : message.getReferenceTime().getZone().getId();
- nativeMessages.push(
- new ActionsSuggestionsModel.ConversationMessage(
- personEncoder.encode(message.getAuthor()),
- message.getText().toString(),
- referenceTime,
- timeZone,
- String.join(",", languageDetector.apply(message.getText()))));
- }
- return nativeMessages.toArray(
- new ActionsSuggestionsModel.ConversationMessage[nativeMessages.size()]);
+ /**
+ * Converts the messages to a list of native messages object that the model can understand.
+ *
+ * <p>User id encoding - local user is represented as 0, Other users are numbered according to how
+ * far before they spoke last time in the conversation. For example, considering this
+ * conversation:
+ *
+ * <ul>
+ * <li>User A: xxx
+ * <li>Local user: yyy
+ * <li>User B: zzz
+ * </ul>
+ *
+ * User A will be encoded as 2, user B will be encoded as 1 and local user will be encoded as 0.
+ */
+ public static ActionsSuggestionsModel.ConversationMessage[] toNativeMessages(
+ List<ConversationActions.Message> messages,
+ Function<CharSequence, List<String>> languageDetector) {
+ List<ConversationActions.Message> messagesWithText =
+ messages.stream()
+ .filter(message -> !TextUtils.isEmpty(message.getText()))
+ .collect(Collectors.toCollection(ArrayList::new));
+ if (messagesWithText.isEmpty()) {
+ return new ActionsSuggestionsModel.ConversationMessage[0];
}
-
- /** Returns the result id for logging. */
- public static String createResultId(
- Context context,
- List<ConversationActions.Message> messages,
- int modelVersion,
- List<Locale> modelLocales) {
- final int hash =
- Objects.hash(
- messages.stream().mapToInt(ActionsSuggestionsHelper::hashMessage),
- context.getPackageName(),
- System.currentTimeMillis());
- return ResultIdUtils.createId(modelVersion, modelLocales, hash);
+ Deque<ActionsSuggestionsModel.ConversationMessage> nativeMessages = new ArrayDeque<>();
+ PersonEncoder personEncoder = new PersonEncoder();
+ int size = messagesWithText.size();
+ for (int i = size - 1; i >= 0; i--) {
+ ConversationActions.Message message = messagesWithText.get(i);
+ long referenceTime =
+ message.getReferenceTime() == null
+ ? 0
+ : message.getReferenceTime().toInstant().toEpochMilli();
+ String timeZone =
+ message.getReferenceTime() == null ? null : message.getReferenceTime().getZone().getId();
+ nativeMessages.push(
+ new ActionsSuggestionsModel.ConversationMessage(
+ personEncoder.encode(message.getAuthor()),
+ message.getText().toString(),
+ referenceTime,
+ timeZone,
+ String.join(",", languageDetector.apply(message.getText()))));
}
+ return nativeMessages.toArray(
+ new ActionsSuggestionsModel.ConversationMessage[nativeMessages.size()]);
+ }
- /** Generated labeled intent from an action suggestion and return the resolved result. */
- @Nullable
- public static LabeledIntent.Result createLabeledIntentResult(
- Context context,
- TemplateIntentFactory templateIntentFactory,
- ActionsSuggestionsModel.ActionSuggestion nativeSuggestion) {
- RemoteActionTemplate[] remoteActionTemplates = nativeSuggestion.getRemoteActionTemplates();
- if (remoteActionTemplates == null) {
- TcLog.w(
- TAG,
- "createRemoteAction: Missing template for type "
- + nativeSuggestion.getActionType());
- return null;
- }
- List<LabeledIntent> labeledIntents = templateIntentFactory.create(remoteActionTemplates);
- if (labeledIntents.isEmpty()) {
- return null;
- }
- // Given that we only support implicit intent here, we should expect there is just one
- // intent for each action type.
- LabeledIntent.TitleChooser titleChooser =
- ActionsSuggestionsHelper.createTitleChooser(nativeSuggestion.getActionType());
- return labeledIntents.get(0).resolve(context, titleChooser, null);
+  /** Returns the result id for logging; hashes the messages, package name and current time. */
+  public static String createResultId(
+      Context context,
+      List<ConversationActions.Message> messages,
+      int modelVersion,
+      List<Locale> modelLocales) {
+    // Collect to a List: unlike an IntStream, a List has a value-based hashCode.
+    List<Integer> messageHashes =
+        messages.stream().map(ActionsSuggestionsHelper::hashMessage).collect(Collectors.toList());
+    final int hash =
+        Objects.hash(messageHashes, context.getPackageName(), System.currentTimeMillis());
+    return ResultIdUtils.createId(modelVersion, modelLocales, hash);
+  }
+
+ /** Generated labeled intent from an action suggestion and return the resolved result. */
+ @Nullable
+ public static LabeledIntent.Result createLabeledIntentResult(
+ Context context,
+ TemplateIntentFactory templateIntentFactory,
+ ActionsSuggestionsModel.ActionSuggestion nativeSuggestion) {
+ RemoteActionTemplate[] remoteActionTemplates = nativeSuggestion.getRemoteActionTemplates();
+ if (remoteActionTemplates == null) {
+ TcLog.w(
+ TAG, "createRemoteAction: Missing template for type " + nativeSuggestion.getActionType());
+ return null;
}
-
- /** Returns a {@link LabeledIntent.TitleChooser} for conversation actions use case. */
- @Nullable
- public static LabeledIntent.TitleChooser createTitleChooser(String actionType) {
- if (ConversationAction.TYPE_OPEN_URL.equals(actionType)) {
- return (labeledIntent, resolveInfo) -> {
- if (resolveInfo.handleAllWebDataURI) {
- return labeledIntent.titleWithEntity;
- }
- if ("android".equals(resolveInfo.activityInfo.packageName)) {
- return labeledIntent.titleWithEntity;
- }
- return labeledIntent.titleWithoutEntity;
- };
- }
- return null;
+ List<LabeledIntent> labeledIntents = templateIntentFactory.create(remoteActionTemplates);
+ if (labeledIntents.isEmpty()) {
+ return null;
}
+ // Given that we only support implicit intent here, we should expect there is just one
+ // intent for each action type.
+ LabeledIntent.TitleChooser titleChooser =
+ ActionsSuggestionsHelper.createTitleChooser(nativeSuggestion.getActionType());
+ return labeledIntents.get(0).resolve(context, titleChooser, null);
+ }
- /**
- * Returns a list of {@link ConversationAction}s that have 0 duplicates. Two actions are
- * duplicates if they may look the same to users. This function assumes every
- * ConversationActions with a non-null RemoteAction also have a non-null intent in the extras.
- */
- public static List<ConversationAction> removeActionsWithDuplicates(
- List<ConversationAction> conversationActions) {
- // Ideally, we should compare title and icon here, but comparing icon is expensive and thus
- // we use the component name of the target handler as the heuristic.
- Map<Pair<String, String>, Integer> counter = new ArrayMap<>();
- for (ConversationAction conversationAction : conversationActions) {
- Pair<String, String> representation = getRepresentation(conversationAction);
- if (representation == null) {
- continue;
- }
- Integer existingCount = counter.getOrDefault(representation, 0);
- counter.put(representation, existingCount + 1);
+ /** Returns a {@link LabeledIntent.TitleChooser} for conversation actions use case. */
+ @Nullable
+ public static LabeledIntent.TitleChooser createTitleChooser(String actionType) {
+ if (ConversationAction.TYPE_OPEN_URL.equals(actionType)) {
+ return (labeledIntent, resolveInfo) -> {
+ if (resolveInfo.handleAllWebDataURI) {
+ return labeledIntent.titleWithEntity;
}
- List<ConversationAction> result = new ArrayList<>();
- for (ConversationAction conversationAction : conversationActions) {
- Pair<String, String> representation = getRepresentation(conversationAction);
- if (representation == null || counter.getOrDefault(representation, 0) == 1) {
- result.add(conversationAction);
- }
+ if ("android".equals(resolveInfo.activityInfo.packageName)) {
+ return labeledIntent.titleWithEntity;
}
- return result;
+ return labeledIntent.titleWithoutEntity;
+ };
}
+ return null;
+ }
- @Nullable
- private static Pair<String, String> getRepresentation(ConversationAction conversationAction) {
- RemoteAction remoteAction = conversationAction.getAction();
- if (remoteAction == null) {
- return null;
- }
- Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
- ComponentName componentName = actionIntent.getComponent();
- // Action without a component name will be considered as from the same app.
- String packageName = componentName == null ? null : componentName.getPackageName();
- return new Pair<>(conversationAction.getAction().getTitle().toString(), packageName);
+ /**
+ * Returns a list of {@link ConversationAction}s that have 0 duplicates. Two actions are
+ * duplicates if they may look the same to users. This function assumes every ConversationActions
+ * with a non-null RemoteAction also have a non-null intent in the extras.
+ */
+ public static List<ConversationAction> removeActionsWithDuplicates(
+ List<ConversationAction> conversationActions) {
+ // Ideally, we should compare title and icon here, but comparing icon is expensive and thus
+ // we use the component name of the target handler as the heuristic.
+ Map<Pair<String, String>, Integer> counter = new ArrayMap<>();
+ for (ConversationAction conversationAction : conversationActions) {
+ Pair<String, String> representation = getRepresentation(conversationAction);
+ if (representation == null) {
+ continue;
+ }
+ Integer existingCount = counter.getOrDefault(representation, 0);
+ counter.put(representation, existingCount + 1);
}
-
- private static final class PersonEncoder {
- private final Map<Person, Integer> mMapping = new ArrayMap<>();
- private int mNextUserId = FIRST_NON_LOCAL_USER;
-
- private int encode(Person person) {
- if (ConversationActions.Message.PERSON_USER_SELF.equals(person)) {
- return USER_LOCAL;
- }
- Integer result = mMapping.get(person);
- if (result == null) {
- mMapping.put(person, mNextUserId);
- result = mNextUserId;
- mNextUserId++;
- }
- return result;
- }
+ List<ConversationAction> result = new ArrayList<>();
+ for (ConversationAction conversationAction : conversationActions) {
+ Pair<String, String> representation = getRepresentation(conversationAction);
+ if (representation == null || counter.getOrDefault(representation, 0) == 1) {
+ result.add(conversationAction);
+ }
}
+ return result;
+ }
- private static int hashMessage(ConversationActions.Message message) {
- return Objects.hash(message.getAuthor(), message.getText(), message.getReferenceTime());
+ @Nullable
+ private static Pair<String, String> getRepresentation(ConversationAction conversationAction) {
+ RemoteAction remoteAction = conversationAction.getAction();
+ if (remoteAction == null) {
+ return null;
}
+ Intent actionIntent = ExtrasUtils.getActionIntent(conversationAction.getExtras());
+ ComponentName componentName = actionIntent.getComponent();
+ // Action without a component name will be considered as from the same app.
+ String packageName = componentName == null ? null : componentName.getPackageName();
+ return new Pair<>(conversationAction.getAction().getTitle().toString(), packageName);
+ }
+
+ private static final class PersonEncoder {
+ private final Map<Person, Integer> personToUserIdMap = new ArrayMap<>();
+ private int nextUserId = FIRST_NON_LOCAL_USER;
+
+ private int encode(Person person) {
+ if (ConversationActions.Message.PERSON_USER_SELF.equals(person)) {
+ return USER_LOCAL;
+ }
+ Integer result = personToUserIdMap.get(person);
+ if (result == null) {
+ personToUserIdMap.put(person, nextUserId);
+ result = nextUserId;
+ nextUserId++;
+ }
+ return result;
+ }
+ }
+
+ private static int hashMessage(ConversationActions.Message message) {
+ return Objects.hash(message.getAuthor(), message.getText(), message.getReferenceTime());
+ }
}
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index e842247..32cba2f 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -11,7 +11,7 @@
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
- * limitations under the License
+ * limitations under the License.
*/
package com.android.textclassifier;
@@ -26,160 +26,152 @@
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
-
import com.android.textclassifier.utils.IndentingPrintWriter;
-
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
-
import java.io.FileDescriptor;
import java.io.PrintWriter;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
+/** An implementation of a TextClassifierService. */
public final class DefaultTextClassifierService extends TextClassifierService {
- private static final String TAG = "default_tcs";
+ private static final String TAG = "default_tcs";
- // TODO: Figure out do we need more concurrency.
- private final ListeningExecutorService mNormPriorityExecutor =
- MoreExecutors.listeningDecorator(
- Executors.newFixedThreadPool(
- /* nThreads= */ 2,
- new ThreadFactoryBuilder()
- .setNameFormat("tcs-norm-prio-executor")
- .setPriority(Thread.NORM_PRIORITY)
- .build()));
+ // TODO: Figure out do we need more concurrency.
+ private final ListeningExecutorService normPriorityExecutor =
+ MoreExecutors.listeningDecorator(
+ Executors.newFixedThreadPool(
+ /* nThreads= */ 2,
+ new ThreadFactoryBuilder()
+ .setNameFormat("tcs-norm-prio-executor")
+ .setPriority(Thread.NORM_PRIORITY)
+ .build()));
- private final ListeningExecutorService mLowPriorityExecutor =
- MoreExecutors.listeningDecorator(
- Executors.newSingleThreadExecutor(
- new ThreadFactoryBuilder()
- .setNameFormat("tcs-low-prio-executor")
- .setPriority(Thread.NORM_PRIORITY - 1)
- .build()));
+ private final ListeningExecutorService lowPriorityExecutor =
+ MoreExecutors.listeningDecorator(
+ Executors.newSingleThreadExecutor(
+ new ThreadFactoryBuilder()
+ .setNameFormat("tcs-low-prio-executor")
+ .setPriority(Thread.NORM_PRIORITY - 1)
+ .build()));
- private TextClassifierImpl mTextClassifier;
+ private TextClassifierImpl textClassifier;
- @Override
- public void onCreate() {
- super.onCreate();
- mTextClassifier = new TextClassifierImpl(this, new TextClassificationConstants());
- }
+ @Override
+ public void onCreate() {
+ super.onCreate();
+ textClassifier = new TextClassifierImpl(this, new TextClassificationConstants());
+ }
- @Override
- public void onSuggestSelection(
- TextClassificationSessionId sessionId,
- TextSelection.Request request,
- CancellationSignal cancellationSignal,
- Callback<TextSelection> callback) {
- handleRequestAsync(
- () -> mTextClassifier.suggestSelection(request), callback, cancellationSignal);
- }
+ @Override
+ public void onSuggestSelection(
+ TextClassificationSessionId sessionId,
+ TextSelection.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextSelection> callback) {
+ handleRequestAsync(
+ () -> textClassifier.suggestSelection(request), callback, cancellationSignal);
+ }
- @Override
- public void onClassifyText(
- TextClassificationSessionId sessionId,
- TextClassification.Request request,
- CancellationSignal cancellationSignal,
- Callback<TextClassification> callback) {
- handleRequestAsync(
- () -> mTextClassifier.classifyText(request), callback, cancellationSignal);
- }
+ @Override
+ public void onClassifyText(
+ TextClassificationSessionId sessionId,
+ TextClassification.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextClassification> callback) {
+ handleRequestAsync(() -> textClassifier.classifyText(request), callback, cancellationSignal);
+ }
- @Override
- public void onGenerateLinks(
- TextClassificationSessionId sessionId,
- TextLinks.Request request,
- CancellationSignal cancellationSignal,
- Callback<TextLinks> callback) {
- handleRequestAsync(
- () -> mTextClassifier.generateLinks(request), callback, cancellationSignal);
- }
+ @Override
+ public void onGenerateLinks(
+ TextClassificationSessionId sessionId,
+ TextLinks.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextLinks> callback) {
+ handleRequestAsync(() -> textClassifier.generateLinks(request), callback, cancellationSignal);
+ }
- @Override
- public void onSuggestConversationActions(
- TextClassificationSessionId sessionId,
- ConversationActions.Request request,
- CancellationSignal cancellationSignal,
- Callback<ConversationActions> callback) {
- handleRequestAsync(
- () -> mTextClassifier.suggestConversationActions(request),
- callback,
- cancellationSignal);
- }
+ @Override
+ public void onSuggestConversationActions(
+ TextClassificationSessionId sessionId,
+ ConversationActions.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<ConversationActions> callback) {
+ handleRequestAsync(
+ () -> textClassifier.suggestConversationActions(request), callback, cancellationSignal);
+ }
- @Override
- public void onDetectLanguage(
- TextClassificationSessionId sessionId,
- TextLanguage.Request request,
- CancellationSignal cancellationSignal,
- Callback<TextLanguage> callback) {
- handleRequestAsync(
- () -> mTextClassifier.detectLanguage(request), callback, cancellationSignal);
- }
+ @Override
+ public void onDetectLanguage(
+ TextClassificationSessionId sessionId,
+ TextLanguage.Request request,
+ CancellationSignal cancellationSignal,
+ Callback<TextLanguage> callback) {
+ handleRequestAsync(() -> textClassifier.detectLanguage(request), callback, cancellationSignal);
+ }
- @Override
- public void onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event) {
- handleEvent(() -> mTextClassifier.onSelectionEvent(event));
- }
+ @Override
+ public void onSelectionEvent(TextClassificationSessionId sessionId, SelectionEvent event) {
+ handleEvent(() -> textClassifier.onSelectionEvent(event));
+ }
- @Override
- public void onTextClassifierEvent(
- TextClassificationSessionId sessionId, TextClassifierEvent event) {
- handleEvent(() -> mTextClassifier.onTextClassifierEvent(sessionId, event));
- }
+ @Override
+ public void onTextClassifierEvent(
+ TextClassificationSessionId sessionId, TextClassifierEvent event) {
+ handleEvent(() -> textClassifier.onTextClassifierEvent(sessionId, event));
+ }
- @Override
- protected void dump(FileDescriptor fd, PrintWriter writer, String[] args) {
- IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
- mTextClassifier.dump(indentingPrintWriter);
- indentingPrintWriter.flush();
- }
+ @Override
+ protected void dump(FileDescriptor fd, PrintWriter writer, String[] args) {
+ IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
+ textClassifier.dump(indentingPrintWriter);
+ indentingPrintWriter.flush();
+ }
- private <T> void handleRequestAsync(
- Callable<T> callable, Callback<T> callback, CancellationSignal cancellationSignal) {
- ListenableFuture<T> result = mNormPriorityExecutor.submit(callable);
- Futures.addCallback(
- result,
- new FutureCallback<T>() {
- @Override
- public void onSuccess(T result) {
- callback.onSuccess(result);
- }
+ private <T> void handleRequestAsync(
+ Callable<T> callable, Callback<T> callback, CancellationSignal cancellationSignal) {
+ ListenableFuture<T> result = normPriorityExecutor.submit(callable);
+ Futures.addCallback(
+ result,
+ new FutureCallback<T>() {
+ @Override
+ public void onSuccess(T result) {
+ callback.onSuccess(result);
+ }
- @Override
- public void onFailure(Throwable t) {
- TcLog.e(TAG, "onFailure: ", t);
- callback.onFailure(t.getMessage());
- }
- },
- MoreExecutors.directExecutor());
- cancellationSignal.setOnCancelListener(
- () -> result.cancel(/* mayInterruptIfRunning= */ true));
- }
+ @Override
+ public void onFailure(Throwable t) {
+ TcLog.e(TAG, "onFailure: ", t);
+ callback.onFailure(t.getMessage());
+ }
+ },
+ MoreExecutors.directExecutor());
+ cancellationSignal.setOnCancelListener(() -> result.cancel(/* mayInterruptIfRunning= */ true));
+ }
- private void handleEvent(Runnable runnable) {
- ListenableFuture<Void> result =
- mLowPriorityExecutor.submit(
- () -> {
- runnable.run();
- return null;
- });
- Futures.addCallback(
- result,
- new FutureCallback<Void>() {
- @Override
- public void onSuccess(Void result) {}
+ private void handleEvent(Runnable runnable) {
+ ListenableFuture<Void> result =
+ lowPriorityExecutor.submit(
+ () -> {
+ runnable.run();
+ return null;
+ });
+ Futures.addCallback(
+ result,
+ new FutureCallback<Void>() {
+ @Override
+ public void onSuccess(Void result) {}
- @Override
- public void onFailure(Throwable t) {
- TcLog.e(TAG, "onFailure: ", t);
- }
- },
- MoreExecutors.directExecutor());
- }
+ @Override
+ public void onFailure(Throwable t) {
+ TcLog.e(TAG, "onFailure: ", t);
+ }
+ },
+ MoreExecutors.directExecutor());
+ }
}
diff --git a/java/src/com/android/textclassifier/Entity.java b/java/src/com/android/textclassifier/Entity.java
index 8a29923..6410a3e 100644
--- a/java/src/com/android/textclassifier/Entity.java
+++ b/java/src/com/android/textclassifier/Entity.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,62 +17,65 @@
package com.android.textclassifier;
import androidx.annotation.FloatRange;
-
import com.google.common.base.Objects;
import com.google.common.base.Preconditions;
/** A representation of an identified entity with the confidence score */
public final class Entity implements Comparable<Entity> {
- private final String mEntityType;
- private float mScore;
+ private final String entityType;
+ private final float score;
- public Entity(String entityType, float score) {
- mEntityType = Preconditions.checkNotNull(entityType);
- mScore = score;
- }
+ public Entity(String entityType, float score) {
+ this.entityType = Preconditions.checkNotNull(entityType);
+ this.score = score;
+ }
- public String getEntityType() {
- return mEntityType;
- }
+ public String getEntityType() {
+ return entityType;
+ }
- /**
- * Returns the confidence score of the entity, which ranged from 0.0 (low confidence) to 1.0
- * (high confidence).
- */
- @FloatRange(from = 0.0, to = 1.0)
- public Float getScore() {
- return mScore;
- }
+ /**
+ * Returns the confidence score of the entity, which ranged from 0.0 (low confidence) to 1.0 (high
+ * confidence).
+ */
+ @FloatRange(from = 0.0, to = 1.0)
+ public Float getScore() {
+ return score;
+ }
- @Override
- public int hashCode() {
- return Objects.hashCode(mEntityType, mScore);
- }
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(entityType, score);
+ }
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- Entity entity = (Entity) o;
- return Float.compare(entity.mScore, mScore) == 0
- && java.util.Objects.equals(mEntityType, entity.mEntityType);
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
}
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ Entity entity = (Entity) o;
+ return Float.compare(entity.score, score) == 0
+ && java.util.Objects.equals(entityType, entity.entityType);
+ }
- @Override
- public String toString() {
- return "Entity{" + mEntityType + ": " + mScore + "}";
- }
+ @Override
+ public String toString() {
+ return "Entity{" + entityType + ": " + score + "}";
+ }
- @Override
- public int compareTo(Entity entity) {
- // This method is implemented for sorting Entity. Sort the entities by the confidence score
- // in descending order firstly. If the scores are the same, then sort them by the entity
- // type in ascending order.
- int result = Float.compare(entity.getScore(), mScore);
- if (result == 0) {
- return mEntityType.compareTo(entity.getEntityType());
- }
- return result;
+ @Override
+ public int compareTo(Entity entity) {
+ // This method is implemented for sorting Entity. Sort the entities by the confidence score
+ // in descending order firstly. If the scores are the same, then sort them by the entity
+ // type in ascending order.
+ int result = Float.compare(entity.getScore(), score);
+ if (result == 0) {
+ return entityType.compareTo(entity.getEntityType());
}
+ return result;
+ }
}
diff --git a/java/src/com/android/textclassifier/EntityConfidence.java b/java/src/com/android/textclassifier/EntityConfidence.java
index d3d26e4..ef8ff05 100644
--- a/java/src/com/android/textclassifier/EntityConfidence.java
+++ b/java/src/com/android/textclassifier/EntityConfidence.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,10 +17,8 @@
package com.android.textclassifier;
import androidx.annotation.FloatRange;
-import androidx.annotation.NonNull;
import androidx.collection.ArrayMap;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -29,65 +27,64 @@
/** Helper object for setting and getting entity scores for classified text. */
final class EntityConfidence {
- static final EntityConfidence EMPTY = new EntityConfidence(Collections.emptyMap());
+ static final EntityConfidence EMPTY = new EntityConfidence(Collections.emptyMap());
- private final ArrayMap<String, Float> mEntityConfidence = new ArrayMap<>();
- private final ArrayList<String> mSortedEntities = new ArrayList<>();
+ private final ArrayMap<String, Float> entityConfidence = new ArrayMap<>();
+ private final ArrayList<String> sortedEntities = new ArrayList<>();
- /**
- * Constructs an EntityConfidence from a map of entity to confidence.
- *
- * <p>Map entries that have 0 confidence are removed, and values greater than 1 are clamped to
- * 1.
- *
- * @param source a map from entity to a confidence value in the range 0 (low confidence) to 1
- * (high confidence).
- */
- EntityConfidence(@NonNull Map<String, Float> source) {
- Preconditions.checkNotNull(source);
+ /**
+ * Constructs an EntityConfidence from a map of entity to confidence.
+ *
+ * <p>Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1.
+ *
+ * @param source a map from entity to a confidence value in the range 0 (low confidence) to 1
+ * (high confidence).
+ */
+ EntityConfidence(Map<String, Float> source) {
+ Preconditions.checkNotNull(source);
- // Prune non-existent entities and clamp to 1.
- mEntityConfidence.ensureCapacity(source.size());
- for (Map.Entry<String, Float> it : source.entrySet()) {
- if (it.getValue() <= 0) continue;
- mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
- }
- resetSortedEntitiesFromMap();
+ // Prune non-existent entities and clamp to 1.
+ entityConfidence.ensureCapacity(source.size());
+ for (Map.Entry<String, Float> it : source.entrySet()) {
+ if (it.getValue() <= 0) {
+ continue;
+ }
+ entityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
}
+ resetSortedEntitiesFromMap();
+ }
- /**
- * Returns an immutable list of entities found in the classified text ordered from high
- * confidence to low confidence.
- */
- @NonNull
- public List<String> getEntities() {
- return Collections.unmodifiableList(mSortedEntities);
- }
+ /**
+ * Returns an immutable list of entities found in the classified text ordered from high confidence
+ * to low confidence.
+ */
+ public List<String> getEntities() {
+ return Collections.unmodifiableList(sortedEntities);
+ }
- /**
- * Returns the confidence score for the specified entity. The value ranges from 0 (low
- * confidence) to 1 (high confidence). 0 indicates that the entity was not found for the
- * classified text.
- */
- @FloatRange(from = 0.0, to = 1.0)
- public float getConfidenceScore(String entity) {
- return mEntityConfidence.getOrDefault(entity, 0f);
- }
+ /**
+ * Returns the confidence score for the specified entity. The value ranges from 0 (low confidence)
+ * to 1 (high confidence). 0 indicates that the entity was not found for the classified text.
+ */
+ @FloatRange(from = 0.0, to = 1.0)
+ public float getConfidenceScore(String entity) {
+ return entityConfidence.getOrDefault(entity, 0f);
+ }
- @Override
- public String toString() {
- return mEntityConfidence.toString();
- }
+ @Override
+ public String toString() {
+ return entityConfidence.toString();
+ }
- private void resetSortedEntitiesFromMap() {
- mSortedEntities.clear();
- mSortedEntities.ensureCapacity(mEntityConfidence.size());
- mSortedEntities.addAll(mEntityConfidence.keySet());
- mSortedEntities.sort(
- (e1, e2) -> {
- float score1 = mEntityConfidence.get(e1);
- float score2 = mEntityConfidence.get(e2);
- return Float.compare(score2, score1);
- });
- }
+ private void resetSortedEntitiesFromMap() {
+ sortedEntities.clear();
+ sortedEntities.ensureCapacity(entityConfidence.size());
+ sortedEntities.addAll(entityConfidence.keySet());
+ sortedEntities.sort(
+ (e1, e2) -> {
+ float score1 = entityConfidence.get(e1);
+ float score2 = entityConfidence.get(e2);
+ return Float.compare(score2, score1);
+ });
+ }
}
diff --git a/java/src/com/android/textclassifier/ExtrasUtils.java b/java/src/com/android/textclassifier/ExtrasUtils.java
index 8d18e19..2039d76 100644
--- a/java/src/com/android/textclassifier/ExtrasUtils.java
+++ b/java/src/com/android/textclassifier/ExtrasUtils.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,295 +23,287 @@
import android.view.textclassifier.TextClassification;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextLinks;
-
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-
import com.google.android.textclassifier.AnnotatorModel;
-
+import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
+import javax.annotation.Nullable;
/** Utility class for inserting and retrieving data in TextClassifier request/response extras. */
// TODO: Make this a TestApi for CTS testing.
public final class ExtrasUtils {
- // Keys for response objects.
- private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
- private static final String ENTITIES_EXTRAS = "entities-extras";
- private static final String ACTION_INTENT = "action-intent";
- private static final String ACTIONS_INTENTS = "actions-intents";
- private static final String FOREIGN_LANGUAGE = "foreign-language";
- private static final String ENTITY_TYPE = "entity-type";
- private static final String SCORE = "score";
- private static final String MODEL_VERSION = "model-version";
- private static final String MODEL_NAME = "model-name";
- private static final String TEXT_LANGUAGES = "text-languages";
- private static final String ENTITIES = "entities";
+ // Keys for response objects.
+ private static final String SERIALIZED_ENTITIES_DATA = "serialized-entities-data";
+ private static final String ENTITIES_EXTRAS = "entities-extras";
+ private static final String ACTION_INTENT = "action-intent";
+ private static final String ACTIONS_INTENTS = "actions-intents";
+ private static final String FOREIGN_LANGUAGE = "foreign-language";
+ private static final String ENTITY_TYPE = "entity-type";
+ private static final String SCORE = "score";
+ private static final String MODEL_VERSION = "model-version";
+ private static final String MODEL_NAME = "model-name";
+ private static final String TEXT_LANGUAGES = "text-languages";
+ private static final String ENTITIES = "entities";
- // Keys for request objects.
- private static final String IS_SERIALIZED_ENTITY_DATA_ENABLED =
- "is-serialized-entity-data-enabled";
+ // Keys for request objects.
+ private static final String IS_SERIALIZED_ENTITY_DATA_ENABLED =
+ "is-serialized-entity-data-enabled";
- private ExtrasUtils() {}
+ private ExtrasUtils() {}
- /** Bundles and returns foreign language detection information for TextClassifier responses. */
- static Bundle createForeignLanguageExtra(String language, float score, int modelVersion) {
- final Bundle bundle = new Bundle();
- bundle.putString(ENTITY_TYPE, language);
- bundle.putFloat(SCORE, score);
- bundle.putInt(MODEL_VERSION, modelVersion);
- bundle.putString(MODEL_NAME, "langId_v" + modelVersion);
- return bundle;
+ /** Bundles and returns foreign language detection information for TextClassifier responses. */
+ static Bundle createForeignLanguageExtra(String language, float score, int modelVersion) {
+ final Bundle bundle = new Bundle();
+ bundle.putString(ENTITY_TYPE, language);
+ bundle.putFloat(SCORE, score);
+ bundle.putInt(MODEL_VERSION, modelVersion);
+ bundle.putString(MODEL_NAME, "langId_v" + modelVersion);
+ return bundle;
+ }
+
+ /**
+ * Stores {@code extra} as foreign language information in TextClassifier response object's extras
+ * {@code container}.
+ *
+ * @see #getForeignLanguageExtra(TextClassification)
+ */
+ static void putForeignLanguageExtra(Bundle container, Bundle extra) {
+ container.putParcelable(FOREIGN_LANGUAGE, extra);
+ }
+
+ /**
+ * Returns foreign language detection information contained in the TextClassification object.
+ * responses.
+ *
+ * @see #putForeignLanguageExtra(Bundle, Bundle)
+ */
+ @Nullable
+ @VisibleForTesting
+ public static Bundle getForeignLanguageExtra(@Nullable TextClassification classification) {
+ if (classification == null) {
+ return null;
}
+ return classification.getExtras().getBundle(FOREIGN_LANGUAGE);
+ }
- /**
- * Stores {@code extra} as foreign language information in TextClassifier response object's
- * extras {@code container}.
- *
- * @see #getForeignLanguageExtra(TextClassification)
- */
- static void putForeignLanguageExtra(Bundle container, Bundle extra) {
- container.putParcelable(FOREIGN_LANGUAGE, extra);
+ /** @see #getTopLanguage(Intent) */
+ static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) {
+ final int maxSize = Math.min(3, languageScores.getEntities().size());
+ final String[] languages =
+ languageScores.getEntities().subList(0, maxSize).toArray(new String[0]);
+ final float[] scores = new float[languages.length];
+ for (int i = 0; i < languages.length; i++) {
+ scores[i] = languageScores.getConfidenceScore(languages[i]);
}
+ container.putStringArray(ENTITY_TYPE, languages);
+ container.putFloatArray(SCORE, scores);
+ }
- /**
- * Returns foreign language detection information contained in the TextClassification object.
- * responses.
- *
- * @see #putForeignLanguageExtra(Bundle, Bundle)
- */
- @Nullable
- @VisibleForTesting
- public static Bundle getForeignLanguageExtra(@Nullable TextClassification classification) {
- if (classification == null) {
- return null;
+ /** See {@link #putTopLanguageScores(Bundle, EntityConfidence)}. */
+ @Nullable
+ @VisibleForTesting
+ public static ULocale getTopLanguage(@Nullable Intent intent) {
+ if (intent == null) {
+ return null;
+ }
+ final Bundle tcBundle = intent.getBundleExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER);
+ if (tcBundle == null) {
+ return null;
+ }
+ final Bundle textLanguagesExtra = tcBundle.getBundle(TEXT_LANGUAGES);
+ if (textLanguagesExtra == null) {
+ return null;
+ }
+ final String[] languages = textLanguagesExtra.getStringArray(ENTITY_TYPE);
+ final float[] scores = textLanguagesExtra.getFloatArray(SCORE);
+ if (languages == null
+ || scores == null
+ || languages.length == 0
+ || languages.length != scores.length) {
+ return null;
+ }
+ int highestScoringIndex = 0;
+ for (int i = 1; i < languages.length; i++) {
+ if (scores[highestScoringIndex] < scores[i]) {
+ highestScoringIndex = i;
+ }
+ }
+ return ULocale.forLanguageTag(languages[highestScoringIndex]);
+ }
+
+ public static void putTextLanguagesExtra(Bundle container, Bundle extra) {
+ container.putBundle(TEXT_LANGUAGES, extra);
+ }
+
+ /**
+ * Stores {@code actionsIntents} information in TextClassifier response object's extras {@code
+ * container}.
+ */
+ static void putActionsIntents(Bundle container, ArrayList<Intent> actionsIntents) {
+ container.putParcelableArrayList(ACTIONS_INTENTS, actionsIntents);
+ }
+
+ /**
+ * Stores {@code actionIntent} information in TextClassifier response object's extras {@code
+ * container}.
+ */
+ public static void putActionIntent(Bundle container, @Nullable Intent actionIntent) {
+ container.putParcelable(ACTION_INTENT, actionIntent);
+ }
+
+ /** Returns {@code actionIntent} information contained in a TextClassifier response object. */
+ @Nullable
+ public static Intent getActionIntent(Bundle container) {
+ return container.getParcelable(ACTION_INTENT);
+ }
+
+ /**
+ * Stores serialized entity data information in TextClassifier response object's extras {@code
+ * container}.
+ */
+ public static void putSerializedEntityData(
+ Bundle container, @Nullable byte[] serializedEntityData) {
+ container.putByteArray(SERIALIZED_ENTITIES_DATA, serializedEntityData);
+ }
+
+ /** Returns serialized entity data information contained in a TextClassifier response object. */
+ @Nullable
+ public static byte[] getSerializedEntityData(Bundle container) {
+ return container.getByteArray(SERIALIZED_ENTITIES_DATA);
+ }
+
+ /**
+ * Stores {@code entities} information in TextClassifier response object's extras {@code
+ * container}.
+ *
+ * @see {@link #getCopyText(Bundle)}
+ */
+ public static void putEntitiesExtras(Bundle container, @Nullable Bundle entitiesExtras) {
+ container.putParcelable(ENTITIES_EXTRAS, entitiesExtras);
+ }
+
+ /**
+ * Returns {@code entities} information contained in a TextClassifier response object.
+ *
+ * @see {@link #putEntitiesExtras(Bundle, Bundle)}
+ */
+ @Nullable
+ public static String getCopyText(Bundle container) {
+ Bundle entitiesExtras = container.getParcelable(ENTITIES_EXTRAS);
+ if (entitiesExtras == null) {
+ return null;
+ }
+ return entitiesExtras.getString("text");
+ }
+
+ /** Returns {@code actionIntents} information contained in the TextClassification object. */
+ @Nullable
+ public static ArrayList<Intent> getActionsIntents(@Nullable TextClassification classification) {
+ if (classification == null) {
+ return null;
+ }
+ return classification.getExtras().getParcelableArrayList(ACTIONS_INTENTS);
+ }
+
+ /**
+ * Returns the first action found in the {@code classification} object with an intent action
+ * string, {@code intentAction}.
+ */
+ @Nullable
+ @VisibleForTesting
+ public static RemoteAction findAction(
+ @Nullable TextClassification classification, @Nullable String intentAction) {
+ if (classification == null || intentAction == null) {
+ return null;
+ }
+ final ArrayList<Intent> actionIntents = getActionsIntents(classification);
+ if (actionIntents != null) {
+ final int size = actionIntents.size();
+ for (int i = 0; i < size; i++) {
+ final Intent intent = actionIntents.get(i);
+ if (intent != null && intentAction.equals(intent.getAction())) {
+ return classification.getActions().get(i);
}
- return classification.getExtras().getBundle(FOREIGN_LANGUAGE);
+ }
}
+ return null;
+ }
- /** @see #getTopLanguage(Intent) */
- @VisibleForTesting
- static void putTopLanguageScores(Bundle container, EntityConfidence languageScores) {
- final int maxSize = Math.min(3, languageScores.getEntities().size());
- final String[] languages =
- languageScores.getEntities().subList(0, maxSize).toArray(new String[0]);
- final float[] scores = new float[languages.length];
- for (int i = 0; i < languages.length; i++) {
- scores[i] = languageScores.getConfidenceScore(languages[i]);
- }
- container.putStringArray(ENTITY_TYPE, languages);
- container.putFloatArray(SCORE, scores);
- }
+ /** Returns the first "translate" action found in the {@code classification} object. */
+ @Nullable
+ @VisibleForTesting
+ public static RemoteAction findTranslateAction(@Nullable TextClassification classification) {
+ return findAction(classification, Intent.ACTION_TRANSLATE);
+ }
- /** @see #putTopLanguageScores(Bundle, EntityConfidence) */
- @Nullable
- @VisibleForTesting
- public static ULocale getTopLanguage(@Nullable Intent intent) {
- if (intent == null) {
- return null;
- }
- final Bundle tcBundle = intent.getBundleExtra(TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER);
- if (tcBundle == null) {
- return null;
- }
- final Bundle textLanguagesExtra = tcBundle.getBundle(TEXT_LANGUAGES);
- if (textLanguagesExtra == null) {
- return null;
- }
- final String[] languages = textLanguagesExtra.getStringArray(ENTITY_TYPE);
- final float[] scores = textLanguagesExtra.getFloatArray(SCORE);
- if (languages == null
- || scores == null
- || languages.length == 0
- || languages.length != scores.length) {
- return null;
- }
- int highestScoringIndex = 0;
- for (int i = 1; i < languages.length; i++) {
- if (scores[highestScoringIndex] < scores[i]) {
- highestScoringIndex = i;
- }
- }
- return ULocale.forLanguageTag(languages[highestScoringIndex]);
+ /** Returns the entity type contained in the {@code extra}. */
+ @Nullable
+ @VisibleForTesting
+ public static String getEntityType(@Nullable Bundle extra) {
+ if (extra == null) {
+ return null;
}
+ return extra.getString(ENTITY_TYPE);
+ }
- public static void putTextLanguagesExtra(Bundle container, Bundle extra) {
- container.putBundle(TEXT_LANGUAGES, extra);
+ /** Returns the score contained in the {@code extra}. */
+ @VisibleForTesting
+ public static float getScore(Bundle extra) {
+ final int defaultValue = -1;
+ if (extra == null) {
+ return defaultValue;
}
+ return extra.getFloat(SCORE, defaultValue);
+ }
- /**
- * Stores {@code actionIntents} information in TextClassifier response object's extras {@code
- * container}.
- */
- static void putActionsIntents(Bundle container, ArrayList<Intent> actionsIntents) {
- container.putParcelableArrayList(ACTIONS_INTENTS, actionsIntents);
+ /** Returns the model name contained in the {@code extra}. */
+ @Nullable
+ public static String getModelName(@Nullable Bundle extra) {
+ if (extra == null) {
+ return null;
}
+ return extra.getString(MODEL_NAME);
+ }
- /**
- * Stores {@code actionIntents} information in TextClassifier response object's extras {@code
- * container}.
- */
- public static void putActionIntent(Bundle container, @Nullable Intent actionIntent) {
- container.putParcelable(ACTION_INTENT, actionIntent);
+ /** Stores the entities from {@link AnnotatorModel.ClassificationResult} in {@code container}. */
+ public static void putEntities(
+ Bundle container, @Nullable AnnotatorModel.ClassificationResult[] classifications) {
+ if (classifications == null || classifications.length == 0) {
+ return;
}
+ ArrayList<Bundle> entitiesBundle = new ArrayList<>();
+ for (AnnotatorModel.ClassificationResult classification : classifications) {
+ if (classification == null) {
+ continue;
+ }
+ Bundle entityBundle = new Bundle();
+ entityBundle.putString(ENTITY_TYPE, classification.getCollection());
+ entityBundle.putByteArray(SERIALIZED_ENTITIES_DATA, classification.getSerializedEntityData());
+ entitiesBundle.add(entityBundle);
+ }
+ if (!entitiesBundle.isEmpty()) {
+ container.putParcelableArrayList(ENTITIES, entitiesBundle);
+ }
+ }
- /** Returns {@code actionIntent} information contained in a TextClassifier response object. */
- @Nullable
- public static Intent getActionIntent(Bundle container) {
- return container.getParcelable(ACTION_INTENT);
- }
+ /** Returns a list of entities contained in the {@code extra}. */
+ @Nullable
+ @VisibleForTesting
+ public static List<Bundle> getEntities(Bundle container) {
+ return container.getParcelableArrayList(ENTITIES);
+ }
- /**
- * Stores serialized entity data information in TextClassifier response object's extras {@code
- * container}.
- */
- public static void putSerializedEntityData(
- Bundle container, @Nullable byte[] serializedEntityData) {
- container.putByteArray(SERIALIZED_ENTITIES_DATA, serializedEntityData);
- }
+ /** Whether the annotator should populate serialized entity data into the result object. */
+ public static boolean isSerializedEntityDataEnabled(TextLinks.Request request) {
+ return request.getExtras().getBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED);
+ }
- /** Returns serialized entity data information contained in a TextClassifier response object. */
- @Nullable
- public static byte[] getSerializedEntityData(Bundle container) {
- return container.getByteArray(SERIALIZED_ENTITIES_DATA);
- }
-
- /**
- * Stores {@code entities} information in TextClassifier response object's extras {@code
- * container}.
- *
- * @see {@link #getCopyText(Bundle)}
- */
- public static void putEntitiesExtras(Bundle container, @Nullable Bundle entitiesExtras) {
- container.putParcelable(ENTITIES_EXTRAS, entitiesExtras);
- }
-
- /**
- * Returns {@code entities} information contained in a TextClassifier response object.
- *
- * @see {@link #putEntitiesExtras(Bundle, Bundle)}
- */
- @Nullable
- public static String getCopyText(Bundle container) {
- Bundle entitiesExtras = container.getParcelable(ENTITIES_EXTRAS);
- if (entitiesExtras == null) {
- return null;
- }
- return entitiesExtras.getString("text");
- }
-
- /** Returns {@code actionIntents} information contained in the TextClassification object. */
- @Nullable
- public static ArrayList<Intent> getActionsIntents(@Nullable TextClassification classification) {
- if (classification == null) {
- return null;
- }
- return classification.getExtras().getParcelableArrayList(ACTIONS_INTENTS);
- }
-
- /**
- * Returns the first action found in the {@code classification} object with an intent action
- * string, {@code intentAction}.
- */
- @Nullable
- @VisibleForTesting
- public static RemoteAction findAction(
- @Nullable TextClassification classification, @Nullable String intentAction) {
- if (classification == null || intentAction == null) {
- return null;
- }
- final ArrayList<Intent> actionIntents = getActionsIntents(classification);
- if (actionIntents != null) {
- final int size = actionIntents.size();
- for (int i = 0; i < size; i++) {
- final Intent intent = actionIntents.get(i);
- if (intent != null && intentAction.equals(intent.getAction())) {
- return classification.getActions().get(i);
- }
- }
- }
- return null;
- }
-
- /** Returns the first "translate" action found in the {@code classification} object. */
- @Nullable
- @VisibleForTesting
- public static RemoteAction findTranslateAction(@Nullable TextClassification classification) {
- return findAction(classification, Intent.ACTION_TRANSLATE);
- }
-
- /** Returns the entity type contained in the {@code extra}. */
- @Nullable
- @VisibleForTesting
- public static String getEntityType(@Nullable Bundle extra) {
- if (extra == null) {
- return null;
- }
- return extra.getString(ENTITY_TYPE);
- }
-
- /** Returns the score contained in the {@code extra}. */
- @VisibleForTesting
- public static float getScore(Bundle extra) {
- final int defaultValue = -1;
- if (extra == null) {
- return defaultValue;
- }
- return extra.getFloat(SCORE, defaultValue);
- }
-
- /** Returns the model name contained in the {@code extra}. */
- @Nullable
- public static String getModelName(@Nullable Bundle extra) {
- if (extra == null) {
- return null;
- }
- return extra.getString(MODEL_NAME);
- }
-
- /**
- * Stores the entities from {@link AnnotatorModel.ClassificationResult} in {@code container}.
- */
- public static void putEntities(
- Bundle container, @Nullable AnnotatorModel.ClassificationResult[] classifications) {
- if (classifications == null || classifications.length == 0) {
- return;
- }
- ArrayList<Bundle> entitiesBundle = new ArrayList<>();
- for (AnnotatorModel.ClassificationResult classification : classifications) {
- if (classification == null) {
- continue;
- }
- Bundle entityBundle = new Bundle();
- entityBundle.putString(ENTITY_TYPE, classification.getCollection());
- entityBundle.putByteArray(
- SERIALIZED_ENTITIES_DATA, classification.getSerializedEntityData());
- entitiesBundle.add(entityBundle);
- }
- if (!entitiesBundle.isEmpty()) {
- container.putParcelableArrayList(ENTITIES, entitiesBundle);
- }
- }
-
- /** Returns a list of entities contained in the {@code extra}. */
- @Nullable
- @VisibleForTesting
- public static List<Bundle> getEntities(Bundle container) {
- return container.getParcelableArrayList(ENTITIES);
- }
-
- /** Whether the annotator should populate serialized entity data into the result object. */
- public static boolean isSerializedEntityDataEnabled(TextLinks.Request request) {
- return request.getExtras().getBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED);
- }
-
- /**
- * To indicate whether the annotator should populate serialized entity data in the result
- * object.
- */
- @VisibleForTesting
- public static void putIsSerializedEntityDataEnabled(Bundle bundle, boolean isEnabled) {
- bundle.putBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED, isEnabled);
- }
+ /**
+ * To indicate whether the annotator should populate serialized entity data in the result object.
+ */
+ @VisibleForTesting
+ public static void putIsSerializedEntityDataEnabled(Bundle bundle, boolean isEnabled) {
+ bundle.putBoolean(IS_SERIALIZED_ENTITY_DATA_ENABLED, isEnabled);
+ }
}
diff --git a/java/src/com/android/textclassifier/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
index 69ae40e..1d8bf3e 100644
--- a/java/src/com/android/textclassifier/ModelFileManager.java
+++ b/java/src/com/android/textclassifier/ModelFileManager.java
@@ -13,16 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier;
import android.os.LocaleList;
import android.os.ParcelFileDescriptor;
import android.text.TextUtils;
-
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
+import com.google.common.base.Splitter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
@@ -36,263 +34,262 @@
import java.util.function.Supplier;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
+import javax.annotation.Nullable;
/** Manages model files that are listed by the model files supplier. */
-@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
-public final class ModelFileManager {
- private static final String TAG = "ModelFileManager";
+final class ModelFileManager {
+ private static final String TAG = "ModelFileManager";
- private final Object mLock = new Object();
- private final Supplier<List<ModelFile>> mModelFileSupplier;
+ private final Object lock = new Object();
+ private final Supplier<List<ModelFile>> modelFileSupplier;
- private List<ModelFile> mModelFiles;
+ private List<ModelFile> modelFiles;
- public ModelFileManager(Supplier<List<ModelFile>> modelFileSupplier) {
- mModelFileSupplier = Preconditions.checkNotNull(modelFileSupplier);
+ public ModelFileManager(Supplier<List<ModelFile>> modelFileSupplier) {
+ this.modelFileSupplier = Preconditions.checkNotNull(modelFileSupplier);
+ }
+
+ /**
+ * Returns an unmodifiable list of model files listed by the given model files supplier.
+ *
+ * <p>The result is cached.
+ */
+ public List<ModelFile> listModelFiles() {
+ synchronized (lock) {
+ if (modelFiles == null) {
+ modelFiles = Collections.unmodifiableList(modelFileSupplier.get());
+ }
+ return modelFiles;
+ }
+ }
+
+ /**
+ * Returns the best model file for the given localelist, {@code null} if nothing is found.
+ *
+ * @param localeList the required locales, use {@code null} if there is no preference.
+ */
+ public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
+ final String languages =
+ localeList == null || localeList.isEmpty()
+ ? LocaleList.getDefault().toLanguageTags()
+ : localeList.toLanguageTags();
+ final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
+
+ ModelFile bestModel = null;
+ for (ModelFile model : listModelFiles()) {
+ if (model.isAnyLanguageSupported(languageRangeList)) {
+ if (model.isPreferredTo(bestModel)) {
+ bestModel = model;
+ }
+ }
+ }
+ return bestModel;
+ }
+
+ /** Default implementation of the model file supplier. */
+ public static final class ModelFileSupplierImpl implements Supplier<List<ModelFile>> {
+ private final File updatedModelFile;
+ private final File factoryModelDir;
+ private final Pattern modelFilenamePattern;
+ private final Function<Integer, Integer> versionSupplier;
+ private final Function<Integer, String> supportedLocalesSupplier;
+
+ public ModelFileSupplierImpl(
+ File factoryModelDir,
+ String factoryModelFileNameRegex,
+ File updatedModelFile,
+ Function<Integer, Integer> versionSupplier,
+ Function<Integer, String> supportedLocalesSupplier) {
+ this.updatedModelFile = Preconditions.checkNotNull(updatedModelFile);
+ this.factoryModelDir = Preconditions.checkNotNull(factoryModelDir);
+ modelFilenamePattern = Pattern.compile(Preconditions.checkNotNull(factoryModelFileNameRegex));
+ this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
+ this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
}
- /**
- * Returns an unmodifiable list of model files listed by the given model files supplier.
- *
- * <p>The result is cached.
- */
- public List<ModelFile> listModelFiles() {
- synchronized (mLock) {
- if (mModelFiles == null) {
- mModelFiles = Collections.unmodifiableList(mModelFileSupplier.get());
- }
- return mModelFiles;
+ @Override
+ public List<ModelFile> get() {
+ final List<ModelFile> modelFiles = new ArrayList<>();
+ // The update model has the highest precedence.
+ if (updatedModelFile.exists()) {
+ final ModelFile updatedModel = createModelFile(updatedModelFile);
+ if (updatedModel != null) {
+ modelFiles.add(updatedModel);
}
+ }
+ // Factory models should never have overlapping locales, so the order doesn't matter.
+ if (factoryModelDir.exists() && factoryModelDir.isDirectory()) {
+ final File[] files = factoryModelDir.listFiles();
+ for (File file : files) {
+ final Matcher matcher = modelFilenamePattern.matcher(file.getName());
+ if (matcher.matches() && file.isFile()) {
+ final ModelFile model = createModelFile(file);
+ if (model != null) {
+ modelFiles.add(model);
+ }
+ }
+ }
+ }
+ return modelFiles;
}
- /**
- * Returns the best model file for the given localelist, {@code null} if nothing is found.
- *
- * @param localeList the required locales, use {@code null} if there is no preference.
- */
- public ModelFile findBestModelFile(@Nullable LocaleList localeList) {
- final String languages =
- localeList == null || localeList.isEmpty()
- ? LocaleList.getDefault().toLanguageTags()
- : localeList.toLanguageTags();
- final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
-
- ModelFile bestModel = null;
- for (ModelFile model : listModelFiles()) {
- if (model.isAnyLanguageSupported(languageRangeList)) {
- if (model.isPreferredTo(bestModel)) {
- bestModel = model;
- }
- }
+ /** Returns null if the path did not point to a compatible model. */
+ @Nullable
+ private ModelFile createModelFile(File file) {
+ if (!file.exists()) {
+ return null;
+ }
+ ParcelFileDescriptor modelFd = null;
+ try {
+ modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ if (modelFd == null) {
+ return null;
}
- return bestModel;
+ final int modelFdInt = modelFd.getFd();
+ final int version = versionSupplier.apply(modelFdInt);
+ final String supportedLocalesStr = supportedLocalesSupplier.apply(modelFdInt);
+ if (supportedLocalesStr.isEmpty()) {
+ TcLog.d(TAG, "Ignoring " + file.getAbsolutePath());
+ return null;
+ }
+ final List<Locale> supportedLocales = new ArrayList<>();
+ for (String langTag : Splitter.on(',').split(supportedLocalesStr)) {
+ supportedLocales.add(Locale.forLanguageTag(langTag));
+ }
+ return new ModelFile(
+ file,
+ version,
+ supportedLocales,
+ supportedLocalesStr,
+ ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
+ } catch (FileNotFoundException e) {
+ TcLog.e(TAG, "Failed to find " + file.getAbsolutePath(), e);
+ return null;
+ } finally {
+ maybeCloseAndLogError(modelFd);
+ }
}
- /** Default implementation of the model file supplier. */
- public static final class ModelFileSupplierImpl implements Supplier<List<ModelFile>> {
- private final File mUpdatedModelFile;
- private final File mFactoryModelDir;
- private final Pattern mModelFilenamePattern;
- private final Function<Integer, Integer> mVersionSupplier;
- private final Function<Integer, String> mSupportedLocalesSupplier;
+ /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
+ private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
+ if (fd == null) {
+ return;
+ }
+ try {
+ fd.close();
+ } catch (IOException e) {
+ TcLog.e(TAG, "Error closing file.", e);
+ }
+ }
+ }
- public ModelFileSupplierImpl(
- File factoryModelDir,
- String factoryModelFileNameRegex,
- File updatedModelFile,
- Function<Integer, Integer> versionSupplier,
- Function<Integer, String> supportedLocalesSupplier) {
- mUpdatedModelFile = Preconditions.checkNotNull(updatedModelFile);
- mFactoryModelDir = Preconditions.checkNotNull(factoryModelDir);
- mModelFilenamePattern =
- Pattern.compile(Preconditions.checkNotNull(factoryModelFileNameRegex));
- mVersionSupplier = Preconditions.checkNotNull(versionSupplier);
- mSupportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
- }
+ /** Describes TextClassifier model files on disk. */
+ public static final class ModelFile {
+ public static final String LANGUAGE_INDEPENDENT = "*";
- @Override
- public List<ModelFile> get() {
- final List<ModelFile> modelFiles = new ArrayList<>();
- // The update model has the highest precedence.
- if (mUpdatedModelFile.exists()) {
- final ModelFile updatedModel = createModelFile(mUpdatedModelFile);
- if (updatedModel != null) {
- modelFiles.add(updatedModel);
- }
- }
- // Factory models should never have overlapping locales, so the order doesn't matter.
- if (mFactoryModelDir.exists() && mFactoryModelDir.isDirectory()) {
- final File[] files = mFactoryModelDir.listFiles();
- for (File file : files) {
- final Matcher matcher = mModelFilenamePattern.matcher(file.getName());
- if (matcher.matches() && file.isFile()) {
- final ModelFile model = createModelFile(file);
- if (model != null) {
- modelFiles.add(model);
- }
- }
- }
- }
- return modelFiles;
- }
+ private final File file;
+ private final int version;
+ private final List<Locale> supportedLocales;
+ private final String supportedLocalesStr;
+ private final boolean languageIndependent;
- /** Returns null if the path did not point to a compatible model. */
- @Nullable
- private ModelFile createModelFile(File file) {
- if (!file.exists()) {
- return null;
- }
- ParcelFileDescriptor modelFd = null;
- try {
- modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
- if (modelFd == null) {
- return null;
- }
- final int modelFdInt = modelFd.getFd();
- final int version = mVersionSupplier.apply(modelFdInt);
- final String supportedLocalesStr = mSupportedLocalesSupplier.apply(modelFdInt);
- if (supportedLocalesStr.isEmpty()) {
- TcLog.d(TAG, "Ignoring " + file.getAbsolutePath());
- return null;
- }
- final List<Locale> supportedLocales = new ArrayList<>();
- for (String langTag : supportedLocalesStr.split(",")) {
- supportedLocales.add(Locale.forLanguageTag(langTag));
- }
- return new ModelFile(
- file,
- version,
- supportedLocales,
- supportedLocalesStr,
- ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
- } catch (FileNotFoundException e) {
- TcLog.e(TAG, "Failed to find " + file.getAbsolutePath(), e);
- return null;
- } finally {
- maybeCloseAndLogError(modelFd);
- }
- }
-
- /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
- private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
- if (fd == null) {
- return;
- }
- try {
- fd.close();
- } catch (IOException e) {
- TcLog.e(TAG, "Error closing file.", e);
- }
- }
+ public ModelFile(
+ File file,
+ int version,
+ List<Locale> supportedLocales,
+ String supportedLocalesStr,
+ boolean languageIndependent) {
+ this.file = Preconditions.checkNotNull(file);
+ this.version = version;
+ this.supportedLocales = Preconditions.checkNotNull(supportedLocales);
+ this.supportedLocalesStr = Preconditions.checkNotNull(supportedLocalesStr);
+ this.languageIndependent = languageIndependent;
}
- /** Describes TextClassifier model files on disk. */
- public static final class ModelFile {
- public static final String LANGUAGE_INDEPENDENT = "*";
-
- private final File mFile;
- private final int mVersion;
- private final List<Locale> mSupportedLocales;
- private final String mSupportedLocalesStr;
- private final boolean mLanguageIndependent;
-
- public ModelFile(
- File file,
- int version,
- List<Locale> supportedLocales,
- String supportedLocalesStr,
- boolean languageIndependent) {
- mFile = Preconditions.checkNotNull(file);
- mVersion = version;
- mSupportedLocales = Preconditions.checkNotNull(supportedLocales);
- mSupportedLocalesStr = Preconditions.checkNotNull(supportedLocalesStr);
- mLanguageIndependent = languageIndependent;
- }
-
- /** Returns the absolute path to the model file. */
- public String getPath() {
- return mFile.getAbsolutePath();
- }
-
- /** Returns a name to use for id generation, effectively the name of the model file. */
- public String getName() {
- return mFile.getName();
- }
-
- /** Returns the version tag in the model's metadata. */
- public int getVersion() {
- return mVersion;
- }
-
- /** Returns whether the language supports any language in the given ranges. */
- public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
- Preconditions.checkNotNull(languageRanges);
- return mLanguageIndependent || Locale.lookup(languageRanges, mSupportedLocales) != null;
- }
-
- /** Returns an immutable lists of supported locales. */
- public List<Locale> getSupportedLocales() {
- return Collections.unmodifiableList(mSupportedLocales);
- }
-
- /** Returns the original supported locals string read from the model file. */
- public String getSupportedLocalesStr() {
- return mSupportedLocalesStr;
- }
-
- /** Returns if this model file is preferred to the given one. */
- public boolean isPreferredTo(@Nullable ModelFile model) {
- // A model is preferred to no model.
- if (model == null) {
- return true;
- }
-
- // A language-specific model is preferred to a language independent
- // model.
- if (!mLanguageIndependent && model.mLanguageIndependent) {
- return true;
- }
- if (mLanguageIndependent && !model.mLanguageIndependent) {
- return false;
- }
-
- // A higher-version model is preferred.
- if (mVersion > model.getVersion()) {
- return true;
- }
- return false;
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(getPath());
- }
-
- @Override
- public boolean equals(Object other) {
- if (this == other) {
- return true;
- }
- if (other instanceof ModelFile) {
- final ModelFile otherModel = (ModelFile) other;
- return TextUtils.equals(getPath(), otherModel.getPath());
- }
- return false;
- }
-
- @Override
- public String toString() {
- final StringJoiner localesJoiner = new StringJoiner(",");
- for (Locale locale : mSupportedLocales) {
- localesJoiner.add(locale.toLanguageTag());
- }
- return String.format(
- Locale.US,
- "ModelFile { path=%s name=%s version=%d locales=%s }",
- getPath(),
- getName(),
- mVersion,
- localesJoiner.toString());
- }
+ /** Returns the absolute path to the model file. */
+ public String getPath() {
+ return file.getAbsolutePath();
}
+
+ /** Returns a name to use for id generation, effectively the name of the model file. */
+ public String getName() {
+ return file.getName();
+ }
+
+ /** Returns the version tag in the model's metadata. */
+ public int getVersion() {
+ return version;
+ }
+
+ /** Returns whether the language supports any language in the given ranges. */
+ public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
+ Preconditions.checkNotNull(languageRanges);
+ return languageIndependent || Locale.lookup(languageRanges, supportedLocales) != null;
+ }
+
+ /** Returns an immutable lists of supported locales. */
+ public List<Locale> getSupportedLocales() {
+ return Collections.unmodifiableList(supportedLocales);
+ }
+
+ /** Returns the original supported locals string read from the model file. */
+ public String getSupportedLocalesStr() {
+ return supportedLocalesStr;
+ }
+
+ /** Returns if this model file is preferred to the given one. */
+ public boolean isPreferredTo(@Nullable ModelFile model) {
+ // A model is preferred to no model.
+ if (model == null) {
+ return true;
+ }
+
+ // A language-specific model is preferred to a language independent
+ // model.
+ if (!languageIndependent && model.languageIndependent) {
+ return true;
+ }
+ if (languageIndependent && !model.languageIndependent) {
+ return false;
+ }
+
+ // A higher-version model is preferred.
+ if (version > model.getVersion()) {
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(getPath());
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (this == other) {
+ return true;
+ }
+ if (other instanceof ModelFile) {
+ final ModelFile otherModel = (ModelFile) other;
+ return TextUtils.equals(getPath(), otherModel.getPath());
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ final StringJoiner localesJoiner = new StringJoiner(",");
+ for (Locale locale : supportedLocales) {
+ localesJoiner.add(locale.toLanguageTag());
+ }
+ return String.format(
+ Locale.US,
+ "ModelFile { path=%s name=%s version=%d locales=%s }",
+ getPath(),
+ getName(),
+ version,
+ localesJoiner);
+ }
+ }
}
diff --git a/java/src/com/android/textclassifier/StringUtils.java b/java/src/com/android/textclassifier/StringUtils.java
index d4968c8..a32d5e4 100644
--- a/java/src/com/android/textclassifier/StringUtils.java
+++ b/java/src/com/android/textclassifier/StringUtils.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -11,14 +11,13 @@
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
- * limitations under the License
+ * limitations under the License.
*/
package com.android.textclassifier;
import androidx.annotation.GuardedBy;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
import java.text.BreakIterator;
/**
@@ -31,52 +30,54 @@
* Intended to be used only for TextClassifier purposes.
*/
public final class StringUtils {
- private static final String TAG = "StringUtils";
+ private static final String TAG = "StringUtils";
- @GuardedBy("WORD_ITERATOR")
- private static final BreakIterator WORD_ITERATOR = BreakIterator.getWordInstance();
+ @GuardedBy("WORD_ITERATOR")
+ private static final BreakIterator WORD_ITERATOR = BreakIterator.getWordInstance();
- /**
- * Returns the substring of {@code text} that contains at least text from index {@code start}
- * <i>(inclusive)</i> to index {@code end} <i><(exclusive)/i> with the goal of returning text
- * that is at least {@code minimumLength}. If {@code text} is not long enough, this will return
- * {@code text}. This method returns text at word boundaries.
- *
- * @param text the source text
- * @param start the start index of text that must be included
- * @param end the end index of text that must be included
- * @param minimumLength minimum length of text to return if {@code text} is long enough
- */
- public static String getSubString(String text, int start, int end, int minimumLength) {
- Preconditions.checkArgument(start >= 0);
- Preconditions.checkArgument(end <= text.length());
- Preconditions.checkArgument(start <= end);
+ /**
+ * Returns the substring of {@code text} that contains at least text from index {@code start}
+ * <i>(inclusive)</i> to index {@code end} <i><(exclusive)/i> with the goal of returning text that
+ * is at least {@code minimumLength}. If {@code text} is not long enough, this will return {@code
+ * text}. This method returns text at word boundaries.
+ *
+ * @param text the source text
+ * @param start the start index of text that must be included
+ * @param end the end index of text that must be included
+ * @param minimumLength minimum length of text to return if {@code text} is long enough
+ */
+ public static String getSubString(String text, int start, int end, int minimumLength) {
+ Preconditions.checkArgument(start >= 0);
+ Preconditions.checkArgument(end <= text.length());
+ Preconditions.checkArgument(start <= end);
- if (text.length() < minimumLength) {
- return text;
- }
-
- final int length = end - start;
- if (length >= minimumLength) {
- return text.substring(start, end);
- }
-
- final int offset = (minimumLength - length) / 2;
- int iterStart = Math.max(0, Math.min(start - offset, text.length() - minimumLength));
- int iterEnd = Math.min(text.length(), iterStart + minimumLength);
-
- synchronized (WORD_ITERATOR) {
- WORD_ITERATOR.setText(text);
- iterStart =
- WORD_ITERATOR.isBoundary(iterStart)
- ? iterStart
- : Math.max(0, WORD_ITERATOR.preceding(iterStart));
- iterEnd =
- WORD_ITERATOR.isBoundary(iterEnd)
- ? iterEnd
- : Math.max(iterEnd, WORD_ITERATOR.following(iterEnd));
- WORD_ITERATOR.setText("");
- return text.substring(iterStart, iterEnd);
- }
+ if (text.length() < minimumLength) {
+ return text;
}
+
+ final int length = end - start;
+ if (length >= minimumLength) {
+ return text.substring(start, end);
+ }
+
+ final int offset = (minimumLength - length) / 2;
+ int iterStart = Math.max(0, Math.min(start - offset, text.length() - minimumLength));
+ int iterEnd = Math.min(text.length(), iterStart + minimumLength);
+
+ synchronized (WORD_ITERATOR) {
+ WORD_ITERATOR.setText(text);
+ iterStart =
+ WORD_ITERATOR.isBoundary(iterStart)
+ ? iterStart
+ : Math.max(0, WORD_ITERATOR.preceding(iterStart));
+ iterEnd =
+ WORD_ITERATOR.isBoundary(iterEnd)
+ ? iterEnd
+ : Math.max(iterEnd, WORD_ITERATOR.following(iterEnd));
+ WORD_ITERATOR.setText("");
+ return text.substring(iterStart, iterEnd);
+ }
+ }
+
+ private StringUtils() {}
}
diff --git a/java/src/com/android/textclassifier/TcLog.java b/java/src/com/android/textclassifier/TcLog.java
index 0c664ed..581d660 100644
--- a/java/src/com/android/textclassifier/TcLog.java
+++ b/java/src/com/android/textclassifier/TcLog.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -25,34 +25,34 @@
* @hide
*/
public final class TcLog {
- private static final boolean USE_TC_TAG = true;
- public static final String TAG = "androidtc";
+ private static final boolean USE_TC_TAG = true;
+ public static final String TAG = "androidtc";
- /** true: Enables full logging. false: Limits logging to debug level. */
- public static final boolean ENABLE_FULL_LOGGING =
- android.util.Log.isLoggable(TAG, android.util.Log.VERBOSE);
+ /** true: Enables full logging. false: Limits logging to debug level. */
+ public static final boolean ENABLE_FULL_LOGGING =
+ android.util.Log.isLoggable(TAG, android.util.Log.VERBOSE);
- private TcLog() {}
+ private TcLog() {}
- public static void v(String tag, String msg) {
- if (ENABLE_FULL_LOGGING) {
- android.util.Log.v(getTag(tag), msg);
- }
+ public static void v(String tag, String msg) {
+ if (ENABLE_FULL_LOGGING) {
+ android.util.Log.v(getTag(tag), msg);
}
+ }
- public static void d(String tag, String msg) {
- android.util.Log.d(getTag(tag), msg);
- }
+ public static void d(String tag, String msg) {
+ android.util.Log.d(getTag(tag), msg);
+ }
- public static void w(String tag, String msg) {
- android.util.Log.w(getTag(tag), msg);
- }
+ public static void w(String tag, String msg) {
+ android.util.Log.w(getTag(tag), msg);
+ }
- public static void e(String tag, String msg, Throwable tr) {
- android.util.Log.e(getTag(tag), msg, tr);
- }
+ public static void e(String tag, String msg, Throwable tr) {
+ android.util.Log.e(getTag(tag), msg, tr);
+ }
- private static String getTag(String customTag) {
- return USE_TC_TAG ? TAG : customTag;
- }
+ private static String getTag(String customTag) {
+ return USE_TC_TAG ? TAG : customTag;
+ }
}
diff --git a/java/src/com/android/textclassifier/TextClassificationConstants.java b/java/src/com/android/textclassifier/TextClassificationConstants.java
index f70c5ac..11b5179 100644
--- a/java/src/com/android/textclassifier/TextClassificationConstants.java
+++ b/java/src/com/android/textclassifier/TextClassificationConstants.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,14 +19,13 @@
import android.provider.DeviceConfig;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.TextClassifier;
-
-import androidx.annotation.Nullable;
-
import com.android.textclassifier.utils.IndentingPrintWriter;
-
+import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import javax.annotation.Nullable;
/**
* TextClassifier specific settings.
@@ -43,318 +42,311 @@
*/
// TODO: Rename to TextClassifierSettings.
public final class TextClassificationConstants {
- private static final String DELIMITER = ":";
+ private static final String DELIMITER = ":";
- /** Whether the user language profile feature is enabled. */
- private static final String USER_LANGUAGE_PROFILE_ENABLED = "user_language_profile_enabled";
- /** Max length of text that suggestSelection can accept. */
- private static final String SUGGEST_SELECTION_MAX_RANGE_LENGTH =
- "suggest_selection_max_range_length";
- /** Max length of text that classifyText can accept. */
- private static final String CLASSIFY_TEXT_MAX_RANGE_LENGTH = "classify_text_max_range_length";
- /** Max length of text that generateLinks can accept. */
- private static final String GENERATE_LINKS_MAX_TEXT_LENGTH = "generate_links_max_text_length";
- /** Sampling rate for generateLinks logging. */
- private static final String GENERATE_LINKS_LOG_SAMPLE_RATE = "generate_links_log_sample_rate";
- /**
- * Extra count that is added to some languages, e.g. system languages, when deducing the
- * frequent languages in {@link
- * com.android.textclassifier.ulp.LanguageProfileAnalyzer#getFrequentLanguages(int)}.
- */
- private static final String FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT =
- "frequent_languages_bootstrapping_count";
+ /** Whether the user language profile feature is enabled. */
+ private static final String USER_LANGUAGE_PROFILE_ENABLED = "user_language_profile_enabled";
+ /** Max length of text that suggestSelection can accept. */
+ private static final String SUGGEST_SELECTION_MAX_RANGE_LENGTH =
+ "suggest_selection_max_range_length";
+ /** Max length of text that classifyText can accept. */
+ private static final String CLASSIFY_TEXT_MAX_RANGE_LENGTH = "classify_text_max_range_length";
+ /** Max length of text that generateLinks can accept. */
+ private static final String GENERATE_LINKS_MAX_TEXT_LENGTH = "generate_links_max_text_length";
+ /** Sampling rate for generateLinks logging. */
+ private static final String GENERATE_LINKS_LOG_SAMPLE_RATE = "generate_links_log_sample_rate";
+ /**
+ * Extra count that is added to some languages, e.g. system languages, when deducing the frequent
+ * languages in {@link
+ * com.android.textclassifier.ulp.LanguageProfileAnalyzer#getFrequentLanguages(int)}.
+ */
+ private static final String FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT =
+ "frequent_languages_bootstrapping_count";
- /**
- * Default count for the language in the system settings while calculating {@code
- * LanguageProfileAnalyzer.getRecognizedLanguages()}
- */
- private static final String LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT =
- "language_proficiency_bootstrapping_count";
+ /**
+ * Default count for the language in the system settings while calculating {@code
+ * LanguageProfileAnalyzer.getRecognizedLanguages()}
+ */
+ private static final String LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT =
+ "language_proficiency_bootstrapping_count";
- /**
- * A colon(:) separated string that specifies the default entities types for generateLinks when
- * hint is not given.
- */
- private static final String ENTITY_LIST_DEFAULT = "entity_list_default";
- /**
- * A colon(:) separated string that specifies the default entities types for generateLinks when
- * the text is in a not editable UI widget.
- */
- private static final String ENTITY_LIST_NOT_EDITABLE = "entity_list_not_editable";
- /**
- * A colon(:) separated string that specifies the default entities types for generateLinks when
- * the text is in an editable UI widget.
- */
- private static final String ENTITY_LIST_EDITABLE = "entity_list_editable";
- /**
- * A colon(:) separated string that specifies the default action types for
- * suggestConversationActions when the suggestions are used in an app.
- */
- private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT =
- "in_app_conversation_action_types_default";
- /**
- * A colon(:) separated string that specifies the default action types for
- * suggestConversationActions when the suggestions are used in a notification.
- */
- private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
- "notification_conversation_action_types_default";
- /** Threshold to accept a suggested language from LangID model. */
- private static final String LANG_ID_THRESHOLD_OVERRIDE = "lang_id_threshold_override";
- /** Whether to enable {@link com.android.textclassifier.intent.TemplateIntentFactory}. */
- private static final String TEMPLATE_INTENT_FACTORY_ENABLED = "template_intent_factory_enabled";
- /** Whether to enable "translate" action in classifyText. */
- private static final String TRANSLATE_IN_CLASSIFICATION_ENABLED =
- "translate_in_classification_enabled";
- /**
- * Whether to detect the languages of the text in request by using langId for the native model.
- */
- private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
- "detect_languages_from_text_enabled";
- /**
- * A colon(:) separated string that specifies the configuration to use when including
- * surrounding context text in language detection queries.
- *
- * <p>Format= minimumTextSize<int>:penalizeRatio<float>:textScoreRatio<float>
- *
- * <p>e.g. 20:1.0:0.4
- *
- * <p>Accept all text lengths with minimumTextSize=0
- *
- * <p>Reject all text less than minimumTextSize with penalizeRatio=0
- *
- * @see {@code TextClassifierImpl#detectLanguages(String, int, int)} for reference.
- */
- private static final String LANG_ID_CONTEXT_SETTINGS = "lang_id_context_settings";
- /** Default threshold to translate the language of the context the user selects */
- private static final String TRANSLATE_ACTION_THRESHOLD = "translate_action_threshold";
+ /**
+ * A colon(:) separated string that specifies the default entities types for generateLinks when
+ * hint is not given.
+ */
+ private static final String ENTITY_LIST_DEFAULT = "entity_list_default";
+ /**
+ * A colon(:) separated string that specifies the default entities types for generateLinks when
+ * the text is in a not editable UI widget.
+ */
+ private static final String ENTITY_LIST_NOT_EDITABLE = "entity_list_not_editable";
+ /**
+ * A colon(:) separated string that specifies the default entities types for generateLinks when
+ * the text is in an editable UI widget.
+ */
+ private static final String ENTITY_LIST_EDITABLE = "entity_list_editable";
+ /**
+ * A colon(:) separated string that specifies the default action types for
+ * suggestConversationActions when the suggestions are used in an app.
+ */
+ private static final String IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT =
+ "in_app_conversation_action_types_default";
+ /**
+ * A colon(:) separated string that specifies the default action types for
+ * suggestConversationActions when the suggestions are used in a notification.
+ */
+ private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
+ "notification_conversation_action_types_default";
+ /** Threshold to accept a suggested language from LangID model. */
+ private static final String LANG_ID_THRESHOLD_OVERRIDE = "lang_id_threshold_override";
+ /** Whether to enable {@link com.android.textclassifier.intent.TemplateIntentFactory}. */
+ private static final String TEMPLATE_INTENT_FACTORY_ENABLED = "template_intent_factory_enabled";
+ /** Whether to enable "translate" action in classifyText. */
+ private static final String TRANSLATE_IN_CLASSIFICATION_ENABLED =
+ "translate_in_classification_enabled";
+ /**
+ * Whether to detect the languages of the text in request by using langId for the native model.
+ */
+ private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
+ "detect_languages_from_text_enabled";
+ /**
+ * A colon(:) separated string that specifies the configuration to use when including surrounding
+ * context text in language detection queries.
+ *
+ * <p>Format= minimumTextSize<int>:penalizeRatio<float>:textScoreRatio<float>
+ *
+ * <p>e.g. 20:1.0:0.4
+ *
+ * <p>Accept all text lengths with minimumTextSize=0
+ *
+ * <p>Reject all text less than minimumTextSize with penalizeRatio=0
+ *
+ * <p>See {@code TextClassifierImpl#detectLanguages(String, int, int)} for reference.
+ */
+ private static final String LANG_ID_CONTEXT_SETTINGS = "lang_id_context_settings";
+ /** Default threshold to translate the language of the context the user selects */
+ private static final String TRANSLATE_ACTION_THRESHOLD = "translate_action_threshold";
- // Sync this with ConversationAction.TYPE_ADD_CONTACT;
- public static final String TYPE_ADD_CONTACT = "add_contact";
- // Sync this with ConversationAction.COPY;
- public static final String TYPE_COPY = "copy";
+ // Sync this with ConversationAction.TYPE_ADD_CONTACT;
+ public static final String TYPE_ADD_CONTACT = "add_contact";
+ // Sync this with ConversationAction.TYPE_COPY;
+ public static final String TYPE_COPY = "copy";
- private static final int SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
- private static final int CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
- private static final int GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT = 100 * 1000;
- private static final int GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT = 100;
- private static final int FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT_DEFAULT = 100;
- private static final int LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT_DEFAULT = 100;
+ private static final int SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
+ private static final int CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT = 10 * 1000;
+ private static final int GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT = 100 * 1000;
+ private static final int GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT = 100;
+ private static final int FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT_DEFAULT = 100;
+ private static final int LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT_DEFAULT = 100;
- private static final List<String> ENTITY_LIST_DEFAULT_VALUE =
- Arrays.asList(
- TextClassifier.TYPE_ADDRESS,
- TextClassifier.TYPE_EMAIL,
- TextClassifier.TYPE_PHONE,
- TextClassifier.TYPE_URL,
- TextClassifier.TYPE_DATE,
- TextClassifier.TYPE_DATE_TIME,
- TextClassifier.TYPE_FLIGHT_NUMBER);
- private static final List<String> CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES =
- Arrays.asList(
- ConversationAction.TYPE_TEXT_REPLY,
- ConversationAction.TYPE_CREATE_REMINDER,
- ConversationAction.TYPE_CALL_PHONE,
- ConversationAction.TYPE_OPEN_URL,
- ConversationAction.TYPE_SEND_EMAIL,
- ConversationAction.TYPE_SEND_SMS,
- ConversationAction.TYPE_TRACK_FLIGHT,
- ConversationAction.TYPE_VIEW_CALENDAR,
- ConversationAction.TYPE_VIEW_MAP,
- TYPE_ADD_CONTACT,
- TYPE_COPY);
- /**
- * < 0 : Not set. Use value from LangId model. 0 - 1: Override value in LangId model.
- *
- * @see EntityConfidence
- */
- private static final float LANG_ID_THRESHOLD_OVERRIDE_DEFAULT = -1f;
+ private static final ImmutableList<String> ENTITY_LIST_DEFAULT_VALUE =
+ ImmutableList.of(
+ TextClassifier.TYPE_ADDRESS,
+ TextClassifier.TYPE_EMAIL,
+ TextClassifier.TYPE_PHONE,
+ TextClassifier.TYPE_URL,
+ TextClassifier.TYPE_DATE,
+ TextClassifier.TYPE_DATE_TIME,
+ TextClassifier.TYPE_FLIGHT_NUMBER);
+ private static final ImmutableList<String> CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES =
+ ImmutableList.of(
+ ConversationAction.TYPE_TEXT_REPLY,
+ ConversationAction.TYPE_CREATE_REMINDER,
+ ConversationAction.TYPE_CALL_PHONE,
+ ConversationAction.TYPE_OPEN_URL,
+ ConversationAction.TYPE_SEND_EMAIL,
+ ConversationAction.TYPE_SEND_SMS,
+ ConversationAction.TYPE_TRACK_FLIGHT,
+ ConversationAction.TYPE_VIEW_CALENDAR,
+ ConversationAction.TYPE_VIEW_MAP,
+ TYPE_ADD_CONTACT,
+ TYPE_COPY);
+ /**
+ * < 0 : Not set. Use value from LangId model. 0 - 1: Override value in LangId model.
+ *
+ * @see EntityConfidence
+ */
+ private static final float LANG_ID_THRESHOLD_OVERRIDE_DEFAULT = -1f;
- private static final float TRANSLATE_ACTION_THRESHOLD_DEFAULT = 0.5f;
+ private static final float TRANSLATE_ACTION_THRESHOLD_DEFAULT = 0.5f;
- private static final boolean USER_LANGUAGE_PROFILE_ENABLED_DEFAULT = true;
- private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
- private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
- private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
- private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
+ private static final boolean USER_LANGUAGE_PROFILE_ENABLED_DEFAULT = true;
+ private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
+ private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
+ private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
+ private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
- public int getSuggestSelectionMaxRangeLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- SUGGEST_SELECTION_MAX_RANGE_LENGTH,
- SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT);
+ public int getSuggestSelectionMaxRangeLength() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ SUGGEST_SELECTION_MAX_RANGE_LENGTH,
+ SUGGEST_SELECTION_MAX_RANGE_LENGTH_DEFAULT);
+ }
+
+ public int getClassifyTextMaxRangeLength() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CLASSIFY_TEXT_MAX_RANGE_LENGTH,
+ CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT);
+ }
+
+ public int getGenerateLinksMaxTextLength() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ GENERATE_LINKS_MAX_TEXT_LENGTH,
+ GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT);
+ }
+
+ public int getGenerateLinksLogSampleRate() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ GENERATE_LINKS_LOG_SAMPLE_RATE,
+ GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT);
+ }
+
+ public int getFrequentLanguagesBootstrappingCount() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT,
+ FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT_DEFAULT);
+ }
+
+ public int getLanguageProficiencyBootstrappingCount() {
+ return DeviceConfig.getInt(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT,
+ LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT_DEFAULT);
+ }
+
+ public List<String> getEntityListDefault() {
+ return getDeviceConfigStringList(ENTITY_LIST_DEFAULT, ENTITY_LIST_DEFAULT_VALUE);
+ }
+
+ public List<String> getEntityListNotEditable() {
+ return getDeviceConfigStringList(ENTITY_LIST_NOT_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
+ }
+
+ public List<String> getEntityListEditable() {
+ return getDeviceConfigStringList(ENTITY_LIST_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
+ }
+
+ public List<String> getInAppConversationActionTypes() {
+ return getDeviceConfigStringList(
+ IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
+ }
+
+ public List<String> getNotificationConversationActionTypes() {
+ return getDeviceConfigStringList(
+ NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT, CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
+ }
+
+ public float getLangIdThresholdOverride() {
+ return DeviceConfig.getFloat(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ LANG_ID_THRESHOLD_OVERRIDE,
+ LANG_ID_THRESHOLD_OVERRIDE_DEFAULT);
+ }
+
+ public float getTranslateActionThreshold() {
+ return DeviceConfig.getFloat(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ TRANSLATE_ACTION_THRESHOLD,
+ TRANSLATE_ACTION_THRESHOLD_DEFAULT);
+ }
+
+ public boolean isUserLanguageProfileEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ USER_LANGUAGE_PROFILE_ENABLED,
+ USER_LANGUAGE_PROFILE_ENABLED_DEFAULT);
+ }
+
+ public boolean isTemplateIntentFactoryEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ TEMPLATE_INTENT_FACTORY_ENABLED,
+ TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT);
+ }
+
+ public boolean isTranslateInClassificationEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ TRANSLATE_IN_CLASSIFICATION_ENABLED,
+ TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT);
+ }
+
+ public boolean isDetectLanguagesFromTextEnabled() {
+ return DeviceConfig.getBoolean(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ DETECT_LANGUAGES_FROM_TEXT_ENABLED,
+ DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT);
+ }
+
+ public float[] getLangIdContextSettings() {
+ return getDeviceConfigFloatArray(LANG_ID_CONTEXT_SETTINGS, LANG_ID_CONTEXT_SETTINGS_DEFAULT);
+ }
+
+ void dump(IndentingPrintWriter pw) { // Each label mirrors the DeviceConfig key it reports.
+ pw.println("TextClassificationConstants:");
+ pw.increaseIndent();
+ pw.printPair("classify_text_max_range_length", getClassifyTextMaxRangeLength());
+ pw.printPair("detect_languages_from_text_enabled", isDetectLanguagesFromTextEnabled());
+ pw.printPair("entity_list_default", getEntityListDefault());
+ pw.printPair("entity_list_editable", getEntityListEditable());
+ pw.printPair("entity_list_not_editable", getEntityListNotEditable());
+ pw.printPair("generate_links_log_sample_rate", getGenerateLinksLogSampleRate());
+ pw.printPair(
+ "frequent_languages_bootstrapping_count", getFrequentLanguagesBootstrappingCount());
+ pw.printPair("generate_links_max_text_length", getGenerateLinksMaxTextLength());
+ pw.printPair("in_app_conversation_action_types_default", getInAppConversationActionTypes());
+ pw.printPair("lang_id_context_settings", Arrays.toString(getLangIdContextSettings()));
+ pw.printPair("lang_id_threshold_override", getLangIdThresholdOverride());
+ pw.printPair("translate_action_threshold", getTranslateActionThreshold());
+ pw.printPair(
+ "notification_conversation_action_types_default", getNotificationConversationActionTypes());
+ pw.printPair("suggest_selection_max_range_length", getSuggestSelectionMaxRangeLength());
+ pw.printPair("user_language_profile_enabled", isUserLanguageProfileEnabled());
+ pw.printPair("template_intent_factory_enabled", isTemplateIntentFactoryEnabled());
+ pw.printPair("translate_in_classification_enabled", isTranslateInClassificationEnabled());
+ pw.printPair(
+ "language_proficiency_bootstrapping_count", getLanguageProficiencyBootstrappingCount());
+ pw.decreaseIndent();
+ }
+
+ private static List<String> getDeviceConfigStringList(String key, List<String> defaultValue) {
+ return parse(
+ DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null), defaultValue);
+ }
+
+ private static float[] getDeviceConfigFloatArray(String key, float[] defaultValue) {
+ return parse( // Copy the default so the shared static array is never handed to callers.
+ DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null), defaultValue.clone());
+ }
+
+ private static List<String> parse(@Nullable String listStr, List<String> defaultValue) {
+ if (listStr != null) {
+ return Collections.unmodifiableList(Arrays.asList(listStr.split(DELIMITER)));
}
+ return defaultValue;
+ }
- public int getClassifyTextMaxRangeLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- CLASSIFY_TEXT_MAX_RANGE_LENGTH,
- CLASSIFY_TEXT_MAX_RANGE_LENGTH_DEFAULT);
- }
-
- public int getGenerateLinksMaxTextLength() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- GENERATE_LINKS_MAX_TEXT_LENGTH,
- GENERATE_LINKS_MAX_TEXT_LENGTH_DEFAULT);
- }
-
- public int getGenerateLinksLogSampleRate() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- GENERATE_LINKS_LOG_SAMPLE_RATE,
- GENERATE_LINKS_LOG_SAMPLE_RATE_DEFAULT);
- }
-
- public int getFrequentLanguagesBootstrappingCount() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT,
- FREQUENT_LANGUAGES_BOOTSTRAPPING_COUNT_DEFAULT);
- }
-
- public int getLanguageProficiencyBootstrappingCount() {
- return DeviceConfig.getInt(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT,
- LANGUAGE_PROFICIENCY_BOOTSTRAPPING_COUNT_DEFAULT);
- }
-
- public List<String> getEntityListDefault() {
- return getDeviceConfigStringList(ENTITY_LIST_DEFAULT, ENTITY_LIST_DEFAULT_VALUE);
- }
-
- public List<String> getEntityListNotEditable() {
- return getDeviceConfigStringList(ENTITY_LIST_NOT_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
- }
-
- public List<String> getEntityListEditable() {
- return getDeviceConfigStringList(ENTITY_LIST_EDITABLE, ENTITY_LIST_DEFAULT_VALUE);
- }
-
- public List<String> getInAppConversationActionTypes() {
- return getDeviceConfigStringList(
- IN_APP_CONVERSATION_ACTION_TYPES_DEFAULT,
- CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
- }
-
- public List<String> getNotificationConversationActionTypes() {
- return getDeviceConfigStringList(
- NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT,
- CONVERSATION_ACTIONS_TYPES_DEFAULT_VALUES);
- }
-
- public float getLangIdThresholdOverride() {
- return DeviceConfig.getFloat(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- LANG_ID_THRESHOLD_OVERRIDE,
- LANG_ID_THRESHOLD_OVERRIDE_DEFAULT);
- }
-
- public float getTranslateActionThreshold() {
- return DeviceConfig.getFloat(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- TRANSLATE_ACTION_THRESHOLD,
- TRANSLATE_ACTION_THRESHOLD_DEFAULT);
- }
-
- public boolean isUserLanguageProfileEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- USER_LANGUAGE_PROFILE_ENABLED,
- USER_LANGUAGE_PROFILE_ENABLED_DEFAULT);
- }
-
- public boolean isTemplateIntentFactoryEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- TEMPLATE_INTENT_FACTORY_ENABLED,
- TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT);
- }
-
- public boolean isTranslateInClassificationEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- TRANSLATE_IN_CLASSIFICATION_ENABLED,
- TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT);
- }
-
- public boolean isDetectLanguagesFromTextEnabled() {
- return DeviceConfig.getBoolean(
- DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
- DETECT_LANGUAGES_FROM_TEXT_ENABLED,
- DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT);
- }
-
- public float[] getLangIdContextSettings() {
- return getDeviceConfigFloatArray(
- LANG_ID_CONTEXT_SETTINGS, LANG_ID_CONTEXT_SETTINGS_DEFAULT);
- }
-
- void dump(IndentingPrintWriter pw) {
- pw.println("TextClassificationConstants:");
- pw.increaseIndent();
- pw.printPair("classify_text_max_range_length", getClassifyTextMaxRangeLength());
- pw.printPair("detect_language_from_text_enabled", isDetectLanguagesFromTextEnabled());
- pw.printPair("entity_list_default", getEntityListDefault());
- pw.printPair("entity_list_editable", getEntityListEditable());
- pw.printPair("entity_list_not_editable", getEntityListNotEditable());
- pw.printPair("generate_links_log_sample_rate", getGenerateLinksLogSampleRate());
- pw.printPair(
- "frequent_languages_bootstrapping_count", getFrequentLanguagesBootstrappingCount());
- pw.printPair("generate_links_max_text_length", getGenerateLinksMaxTextLength());
- pw.printPair("in_app_conversation_action_types_default", getInAppConversationActionTypes());
- pw.printPair("lang_id_context_settings", Arrays.toString(getLangIdContextSettings()));
- pw.printPair("lang_id_threshold_override", getLangIdThresholdOverride());
- pw.printPair("translate_action_threshold", getTranslateActionThreshold());
- pw.printPair(
- "notification_conversation_action_types_default",
- getNotificationConversationActionTypes());
- pw.printPair("suggest_selection_max_range_length", getSuggestSelectionMaxRangeLength());
- pw.printPair("user_language_profile_enabled", isUserLanguageProfileEnabled());
- pw.printPair("template_intent_factory_enabled", isTemplateIntentFactoryEnabled());
- pw.printPair("translate_in_classification_enabled", isTranslateInClassificationEnabled());
- pw.printPair(
- "language proficiency bootstrapping count",
- getLanguageProficiencyBootstrappingCount());
- pw.decreaseIndent();
- }
-
- private static List<String> getDeviceConfigStringList(String key, List<String> defaultValue) {
- return parse(
- DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null),
- defaultValue);
- }
-
- private static float[] getDeviceConfigFloatArray(String key, float[] defaultValue) {
- return parse(
- DeviceConfig.getString(DeviceConfig.NAMESPACE_TEXTCLASSIFIER, key, null),
- defaultValue);
- }
-
- private static List<String> parse(@Nullable String listStr, List<String> defaultValue) {
- if (listStr != null) {
- return Collections.unmodifiableList(Arrays.asList(listStr.split(DELIMITER)));
- }
+ private static float[] parse(@Nullable String arrayStr, float[] defaultValue) {
+ if (arrayStr != null) {
+ final List<String> split = Splitter.on(DELIMITER).splitToList(arrayStr); // literal separator
+ if (split.size() != defaultValue.length) {
return defaultValue;
- }
-
- private static float[] parse(@Nullable String arrayStr, float[] defaultValue) {
- if (arrayStr != null) {
- final String[] split = arrayStr.split(DELIMITER);
- if (split.length != defaultValue.length) {
- return defaultValue;
- }
- final float[] result = new float[split.length];
- for (int i = 0; i < split.length; i++) {
- try {
- result[i] = Float.parseFloat(split[i]);
- } catch (NumberFormatException e) {
- return defaultValue;
- }
- }
- return result;
- } else {
- return defaultValue;
+ }
+ final float[] result = new float[split.size()];
+ for (int i = 0; i < split.size(); i++) {
+ try {
+ result[i] = Float.parseFloat(split.get(i));
+ } catch (NumberFormatException e) {
+ return defaultValue;
}
+ }
+ return result;
+ } else {
+ return defaultValue;
}
+ }
}
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index f385684..9babf30 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -37,15 +37,10 @@
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
-
import androidx.annotation.GuardedBy;
-import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
import androidx.annotation.WorkerThread;
import androidx.collection.ArraySet;
import androidx.core.util.Pair;
-import androidx.core.util.Preconditions;
-
import com.android.textclassifier.ActionsModelParamsSupplier.ActionsModelParams;
import com.android.textclassifier.intent.ClassificationIntentFactory;
import com.android.textclassifier.intent.LabeledIntent;
@@ -59,17 +54,16 @@
import com.android.textclassifier.ulp.LanguageProfileAnalyzer;
import com.android.textclassifier.ulp.LanguageProfileUpdater;
import com.android.textclassifier.utils.IndentingPrintWriter;
-
import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;
-import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.base.Preconditions;
import com.google.common.util.concurrent.MoreExecutors;
-
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.time.Instant;
+import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.ArrayList;
import java.util.Collection;
@@ -80,6 +74,7 @@
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.function.Supplier;
+import javax.annotation.Nullable;
/**
* Default implementation of the {@link TextClassifier} interface.
@@ -90,901 +85,872 @@
*/
final class TextClassifierImpl {
- private static final String TAG = "TextClassifierImpl";
+ private static final String TAG = "TextClassifierImpl";
- private static final boolean DEBUG = false;
+ private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/");
+ // Annotator
+ private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX =
+ "textclassifier\\.(.*)\\.model";
+ private static final File ANNOTATOR_UPDATED_MODEL_FILE =
+ new File("/data/misc/textclassifier/textclassifier.model");
- private static final File FACTORY_MODEL_DIR = new File("/etc/textclassifier/");
- // Annotator
- private static final String ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX =
- "textclassifier\\.(.*)\\.model";
- private static final File ANNOTATOR_UPDATED_MODEL_FILE =
- new File("/data/misc/textclassifier/textclassifier.model");
+ // LangID
+ private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model";
+ private static final File UPDATED_LANG_ID_MODEL_FILE =
+ new File("/data/misc/textclassifier/lang_id.model");
- // LangID
- private static final String LANG_ID_FACTORY_MODEL_FILENAME_REGEX = "lang_id.model";
- private static final File UPDATED_LANG_ID_MODEL_FILE =
- new File("/data/misc/textclassifier/lang_id.model");
+ // Actions
+ private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX =
+ "actions_suggestions\\.(.*)\\.model";
+ private static final File UPDATED_ACTIONS_MODEL =
+ new File("/data/misc/textclassifier/actions_suggestions.model");
- // Actions
- private static final String ACTIONS_FACTORY_MODEL_FILENAME_REGEX =
- "actions_suggestions\\.(.*)\\.model";
- private static final File UPDATED_ACTIONS_MODEL =
- new File("/data/misc/textclassifier/actions_suggestions.model");
+ private final Context context;
+ private final TextClassifier fallback;
+ private final GenerateLinksLogger generateLinksLogger;
- private final Context mContext;
- private final TextClassifier mFallback;
- private final GenerateLinksLogger mGenerateLinksLogger;
+ private final Object lock = new Object();
- private final Object mLock = new Object();
+ @GuardedBy("lock")
+ private ModelFileManager.ModelFile annotatorModelInUse;
- @GuardedBy("mLock")
- private ModelFileManager.ModelFile mAnnotatorModelInUse;
+ @GuardedBy("lock")
+ private AnnotatorModel annotatorImpl;
- @GuardedBy("mLock")
- private AnnotatorModel mAnnotatorImpl;
+ @GuardedBy("lock")
+ private ModelFileManager.ModelFile langIdModelInUse;
- @GuardedBy("mLock")
- private ModelFileManager.ModelFile mLangIdModelInUse;
+ @GuardedBy("lock")
+ private LangIdModel langIdImpl;
- @GuardedBy("mLock")
- private LangIdModel mLangIdImpl;
+ @GuardedBy("lock")
+ private ModelFileManager.ModelFile actionModelInUse;
- @GuardedBy("mLock")
- private ModelFileManager.ModelFile mActionModelInUse;
+ @GuardedBy("lock")
+ private ActionsSuggestionsModel actionsImpl;
- @GuardedBy("mLock")
- private ActionsSuggestionsModel mActionsImpl;
+ private final TextClassifierEventLogger textClassifierEventLogger =
+ new TextClassifierEventLogger();
- private final TextClassifierEventLogger mTextClassifierEventLogger =
- new TextClassifierEventLogger();
+ private final TextClassificationConstants settings;
- private final TextClassificationConstants mSettings;
+ private final ModelFileManager annotatorModelFileManager;
+ private final ModelFileManager langIdModelFileManager;
+ private final ModelFileManager actionsModelFileManager;
+ private final LanguageProfileUpdater languageProfileUpdater;
+ private final LanguageProfileAnalyzer languageProfileAnalyzer;
+ private final ClassificationIntentFactory classificationIntentFactory;
+ private final TemplateIntentFactory templateIntentFactory;
+ private final Supplier<ActionsModelParams> actionsModelParamsSupplier;
- private final ModelFileManager mAnnotatorModelFileManager;
- private final ModelFileManager mLangIdModelFileManager;
- private final ModelFileManager mActionsModelFileManager;
- private final LanguageProfileUpdater mLanguageProfileUpdater;
- private final LanguageProfileAnalyzer mLanguageProfileAnalyzer;
- private final ClassificationIntentFactory mClassificationIntentFactory;
- private final TemplateIntentFactory mTemplateIntentFactory;
- private final Supplier<ActionsModelParams> mActionsModelParamsSupplier;
+ TextClassifierImpl(
+ Context context, TextClassificationConstants settings, TextClassifier fallback) {
+ this.context = Preconditions.checkNotNull(context);
+ this.fallback = Preconditions.checkNotNull(fallback);
+ this.settings = Preconditions.checkNotNull(settings);
+ generateLinksLogger = new GenerateLinksLogger(this.settings.getGenerateLinksLogSampleRate());
+ annotatorModelFileManager =
+ new ModelFileManager(
+ new ModelFileManager.ModelFileSupplierImpl(
+ FACTORY_MODEL_DIR,
+ ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX,
+ ANNOTATOR_UPDATED_MODEL_FILE,
+ AnnotatorModel::getVersion,
+ AnnotatorModel::getLocales));
+ langIdModelFileManager =
+ new ModelFileManager(
+ new ModelFileManager.ModelFileSupplierImpl(
+ FACTORY_MODEL_DIR,
+ LANG_ID_FACTORY_MODEL_FILENAME_REGEX,
+ UPDATED_LANG_ID_MODEL_FILE,
+ LangIdModel::getVersion,
+ fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
+ actionsModelFileManager =
+ new ModelFileManager(
+ new ModelFileManager.ModelFileSupplierImpl(
+ FACTORY_MODEL_DIR,
+ ACTIONS_FACTORY_MODEL_FILENAME_REGEX,
+ UPDATED_ACTIONS_MODEL,
+ ActionsSuggestionsModel::getVersion,
+ ActionsSuggestionsModel::getLocales));
+ languageProfileUpdater =
+ new LanguageProfileUpdater(
+ this.context, MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor()));
+ languageProfileAnalyzer = LanguageProfileAnalyzer.create(context, this.settings);
- TextClassifierImpl(
- Context context, TextClassificationConstants settings, TextClassifier fallback) {
- mContext = Preconditions.checkNotNull(context);
- mFallback = Preconditions.checkNotNull(fallback);
- mSettings = Preconditions.checkNotNull(settings);
- mGenerateLinksLogger = new GenerateLinksLogger(mSettings.getGenerateLinksLogSampleRate());
- mAnnotatorModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- ANNOTATOR_FACTORY_MODEL_FILENAME_REGEX,
- ANNOTATOR_UPDATED_MODEL_FILE,
- AnnotatorModel::getVersion,
- AnnotatorModel::getLocales));
- mLangIdModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- LANG_ID_FACTORY_MODEL_FILENAME_REGEX,
- UPDATED_LANG_ID_MODEL_FILE,
- LangIdModel::getVersion,
- fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT));
- mActionsModelFileManager =
- new ModelFileManager(
- new ModelFileManager.ModelFileSupplierImpl(
- FACTORY_MODEL_DIR,
- ACTIONS_FACTORY_MODEL_FILENAME_REGEX,
- UPDATED_ACTIONS_MODEL,
- ActionsSuggestionsModel::getVersion,
- ActionsSuggestionsModel::getLocales));
- mLanguageProfileUpdater =
- new LanguageProfileUpdater(
- mContext,
- MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor()));
- mLanguageProfileAnalyzer = LanguageProfileAnalyzer.create(context, mSettings);
+ templateIntentFactory = new TemplateIntentFactory();
+ classificationIntentFactory =
+ this.settings.isTemplateIntentFactoryEnabled()
+ ? new TemplateClassificationIntentFactory(
+ templateIntentFactory, new LegacyClassificationIntentFactory())
+ : new LegacyClassificationIntentFactory();
+ actionsModelParamsSupplier =
+ new ActionsModelParamsSupplier(
+ this.context,
+ () -> {
+ synchronized (lock) {
+ // Clear actionsImpl here, so that we will create a new
+ // ActionsSuggestionsModel object with the new flag in the next
+ // request.
+ actionsImpl = null;
+ actionModelInUse = null;
+ }
+ });
+ }
- mTemplateIntentFactory = new TemplateIntentFactory();
- mClassificationIntentFactory =
- mSettings.isTemplateIntentFactoryEnabled()
- ? new TemplateClassificationIntentFactory(
- mTemplateIntentFactory, new LegacyClassificationIntentFactory())
- : new LegacyClassificationIntentFactory();
- mActionsModelParamsSupplier =
- new ActionsModelParamsSupplier(
- mContext,
- () -> {
- synchronized (mLock) {
- // Clear mActionsImpl here, so that we will create a new
- // ActionsSuggestionsModel object with the new flag in the next
- // request.
- mActionsImpl = null;
- mActionModelInUse = null;
- }
- });
- }
+ TextClassifierImpl(Context context, TextClassificationConstants settings) {
+ this(context, settings, TextClassifier.NO_OP);
+ }
- TextClassifierImpl(Context context, TextClassificationConstants settings) {
- this(context, settings, TextClassifier.NO_OP);
- }
-
- @WorkerThread
- TextSelection suggestSelection(TextSelection.Request request) {
- Preconditions.checkNotNull(request);
- checkMainThread();
- try {
- final int rangeLength = request.getEndIndex() - request.getStartIndex();
- final String string = request.getText().toString();
- if (string.length() > 0
- && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
- final String localesString = concatenateLocales(request.getDefaultLocales());
- final String detectLanguageTags =
- String.join(
- ",",
- detectLanguages(request.getText(), getLangIdThreshold())
- .getEntities());
- final ZonedDateTime refTime = ZonedDateTime.now();
- final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
- final int[] startEnd =
- annotatorImpl.suggestSelection(
- string,
- request.getStartIndex(),
- request.getEndIndex(),
- new AnnotatorModel.SelectionOptions(
- localesString, detectLanguageTags));
- final int start = startEnd[0];
- final int end = startEnd[1];
- if (start < end
- && start >= 0
- && end <= string.length()
- && start <= request.getStartIndex()
- && end >= request.getEndIndex()) {
- final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
- final AnnotatorModel.ClassificationResult[] results =
- annotatorImpl.classifyText(
- string,
- start,
- end,
- new AnnotatorModel.ClassificationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- detectLanguageTags),
- // Passing null here to suppress intent generation
- // TODO: Use an explicit flag to suppress it.
- /* appContext */ null,
- /* deviceLocales */ null);
- final int size = results.length;
- for (int i = 0; i < size; i++) {
- tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
- }
- return tsBuilder
- .setId(createId(string, request.getStartIndex(), request.getEndIndex()))
- .build();
- } else {
- // We can not trust the result. Log the issue and ignore the result.
- TcLog.d(TAG, "Got bad indices for input text. Ignoring result.");
- }
- }
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(
- TAG,
- "Error suggesting selection for text. No changes to selection suggested.",
- t);
+ @WorkerThread
+ TextSelection suggestSelection(TextSelection.Request request) {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ try {
+ final int rangeLength = request.getEndIndex() - request.getStartIndex();
+ final String string = request.getText().toString();
+ if (string.length() > 0 && rangeLength <= settings.getSuggestSelectionMaxRangeLength()) {
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ final String detectLanguageTags =
+ String.join(
+ ",", detectLanguages(request.getText(), getLangIdThreshold()).getEntities());
+ final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
+ final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+ final int[] startEnd =
+ annotatorImpl.suggestSelection(
+ string,
+ request.getStartIndex(),
+ request.getEndIndex(),
+ new AnnotatorModel.SelectionOptions(localesString, detectLanguageTags));
+ final int start = startEnd[0];
+ final int end = startEnd[1];
+ if (start < end
+ && start >= 0
+ && end <= string.length()
+ && start <= request.getStartIndex()
+ && end >= request.getEndIndex()) {
+ final TextSelection.Builder tsBuilder = new TextSelection.Builder(start, end);
+ final AnnotatorModel.ClassificationResult[] results =
+ annotatorImpl.classifyText(
+ string,
+ start,
+ end,
+ new AnnotatorModel.ClassificationOptions(
+ refTime.toInstant().toEpochMilli(),
+ refTime.getZone().getId(),
+ localesString,
+ detectLanguageTags),
+ // Passing null here to suppress intent generation
+ // TODO: Use an explicit flag to suppress it.
+ /* appContext */ null,
+ /* deviceLocales */ null);
+ final int size = results.length;
+ for (int i = 0; i < size; i++) {
+ tsBuilder.setEntityType(results[i].getCollection(), results[i].getScore());
+ }
+ return tsBuilder
+ .setId(createId(string, request.getStartIndex(), request.getEndIndex()))
+ .build();
+ } else {
+ // We can not trust the result. Log the issue and ignore the result.
+ TcLog.d(TAG, "Got bad indices for input text. Ignoring result.");
}
- // Getting here means something went wrong, return a NO_OP result.
- return mFallback.suggestSelection(request);
+ }
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error suggesting selection for text. No changes to selection suggested.", t);
}
+ // Getting here means something went wrong, return a NO_OP result.
+ return fallback.suggestSelection(request);
+ }
- @WorkerThread
- TextClassification classifyText(TextClassification.Request request) {
- Preconditions.checkNotNull(request);
- checkMainThread();
- try {
- List<String> detectLanguageTags =
- detectLanguages(request.getText(), getLangIdThreshold()).getEntities();
- if (mSettings.isUserLanguageProfileEnabled()) {
- ListenableFuture<Void> ignoredResult =
- mLanguageProfileUpdater.updateFromClassifyTextAsync(detectLanguageTags);
- }
- final int rangeLength = request.getEndIndex() - request.getStartIndex();
- final String string = request.getText().toString();
- if (string.length() > 0 && rangeLength <= mSettings.getClassifyTextMaxRangeLength()) {
- final String localesString = concatenateLocales(request.getDefaultLocales());
- final ZonedDateTime refTime =
- request.getReferenceTime() != null
- ? request.getReferenceTime()
- : ZonedDateTime.now();
- final AnnotatorModel.ClassificationResult[] results =
- getAnnotatorImpl(request.getDefaultLocales())
- .classifyText(
- string,
- request.getStartIndex(),
- request.getEndIndex(),
- new AnnotatorModel.ClassificationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- String.join(",", detectLanguageTags)),
- mContext,
- getResourceLocalesString());
- if (results.length > 0) {
- return createClassificationResult(
- results,
- string,
- request.getStartIndex(),
- request.getEndIndex(),
- refTime.toInstant());
- }
- }
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error getting text classification info.", t);
+ @WorkerThread
+ TextClassification classifyText(TextClassification.Request request) {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ try {
+ List<String> detectLanguageTags =
+ detectLanguages(request.getText(), getLangIdThreshold()).getEntities();
+ if (settings.isUserLanguageProfileEnabled()) {
+ languageProfileUpdater.updateFromClassifyTextAsync(detectLanguageTags);
+ }
+ final int rangeLength = request.getEndIndex() - request.getStartIndex();
+ final String string = request.getText().toString();
+ if (string.length() > 0 && rangeLength <= settings.getClassifyTextMaxRangeLength()) {
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ final ZonedDateTime refTime =
+ request.getReferenceTime() != null
+ ? request.getReferenceTime()
+ : ZonedDateTime.now(ZoneId.systemDefault());
+ final AnnotatorModel.ClassificationResult[] results =
+ getAnnotatorImpl(request.getDefaultLocales())
+ .classifyText(
+ string,
+ request.getStartIndex(),
+ request.getEndIndex(),
+ new AnnotatorModel.ClassificationOptions(
+ refTime.toInstant().toEpochMilli(),
+ refTime.getZone().getId(),
+ localesString,
+ String.join(",", detectLanguageTags)),
+ context,
+ getResourceLocalesString());
+ if (results.length > 0) {
+ return createClassificationResult(
+ results, string, request.getStartIndex(), request.getEndIndex(), refTime.toInstant());
}
- // Getting here means something went wrong, return a NO_OP result.
- return mFallback.classifyText(request);
+ }
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error getting text classification info.", t);
}
+ // Getting here means something went wrong, return a NO_OP result.
+ return fallback.classifyText(request);
+ }
- @WorkerThread
- TextLinks generateLinks(@NonNull TextLinks.Request request) {
- Preconditions.checkNotNull(request);
- Preconditions.checkArgumentInRange(
- request.getText().length(), 0, getMaxGenerateLinksTextLength(), "text.length()");
- checkMainThread();
+ @WorkerThread
+ TextLinks generateLinks(TextLinks.Request request) {
+ Preconditions.checkNotNull(request);
+ Preconditions.checkArgument(
+ request.getText().length() <= getMaxGenerateLinksTextLength(),
+ "text.length() cannot be greater than %s",
+ getMaxGenerateLinksTextLength());
+ checkMainThread();
- final String textString = request.getText().toString();
- final TextLinks.Builder builder = new TextLinks.Builder(textString);
+ final String textString = request.getText().toString();
+ final TextLinks.Builder builder = new TextLinks.Builder(textString);
- try {
- final long startTimeMs = System.currentTimeMillis();
- final ZonedDateTime refTime = ZonedDateTime.now();
- final Collection<String> entitiesToIdentify =
- request.getEntityConfig() != null
- ? request.getEntityConfig()
- .resolveEntityListModifications(
- getEntitiesForHints(
- request.getEntityConfig().getHints()))
- : mSettings.getEntityListDefault();
- final String localesString = concatenateLocales(request.getDefaultLocales());
- final String detectLanguageTags =
- String.join(
- ",",
- detectLanguages(request.getText(), getLangIdThreshold()).getEntities());
- final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
- final boolean isSerializedEntityDataEnabled =
- ExtrasUtils.isSerializedEntityDataEnabled(request);
- final AnnotatorModel.AnnotatedSpan[] annotations =
- annotatorImpl.annotate(
- textString,
- new AnnotatorModel.AnnotationOptions(
- refTime.toInstant().toEpochMilli(),
- refTime.getZone().getId(),
- localesString,
- detectLanguageTags,
- entitiesToIdentify,
- AnnotatorModel.AnnotationUsecase.SMART.getValue(),
- isSerializedEntityDataEnabled));
- for (AnnotatorModel.AnnotatedSpan span : annotations) {
- final AnnotatorModel.ClassificationResult[] results = span.getClassification();
- if (results.length == 0
- || !entitiesToIdentify.contains(results[0].getCollection())) {
- continue;
- }
- final Map<String, Float> entityScores = new ArrayMap<>();
- for (int i = 0; i < results.length; i++) {
- entityScores.put(results[i].getCollection(), results[i].getScore());
- }
- Bundle extras = new Bundle();
- if (isSerializedEntityDataEnabled) {
- ExtrasUtils.putEntities(extras, results);
- }
- builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
- }
- final TextLinks links = builder.build();
- final long endTimeMs = System.currentTimeMillis();
- final String callingPackageName =
- request.getCallingPackageName() == null
- ? mContext.getPackageName() // local (in process) TC.
- : request.getCallingPackageName();
- mGenerateLinksLogger.logGenerateLinks(
- request.getText(), links, callingPackageName, endTimeMs - startTimeMs);
- return links;
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error getting links info.", t);
+ try {
+ final long startTimeMs = System.currentTimeMillis();
+ final ZonedDateTime refTime = ZonedDateTime.now(ZoneId.systemDefault());
+ final Collection<String> entitiesToIdentify =
+ request.getEntityConfig() != null
+ ? request
+ .getEntityConfig()
+ .resolveEntityListModifications(
+ getEntitiesForHints(request.getEntityConfig().getHints()))
+ : settings.getEntityListDefault();
+ final String localesString = concatenateLocales(request.getDefaultLocales());
+ final String detectLanguageTags =
+ String.join(",", detectLanguages(request.getText(), getLangIdThreshold()).getEntities());
+ final AnnotatorModel annotatorImpl = getAnnotatorImpl(request.getDefaultLocales());
+ final boolean isSerializedEntityDataEnabled =
+ ExtrasUtils.isSerializedEntityDataEnabled(request);
+ final AnnotatorModel.AnnotatedSpan[] annotations =
+ annotatorImpl.annotate(
+ textString,
+ new AnnotatorModel.AnnotationOptions(
+ refTime.toInstant().toEpochMilli(),
+ refTime.getZone().getId(),
+ localesString,
+ detectLanguageTags,
+ entitiesToIdentify,
+ AnnotatorModel.AnnotationUsecase.SMART.getValue(),
+ isSerializedEntityDataEnabled));
+ for (AnnotatorModel.AnnotatedSpan span : annotations) {
+ final AnnotatorModel.ClassificationResult[] results = span.getClassification();
+ if (results.length == 0 || !entitiesToIdentify.contains(results[0].getCollection())) {
+ continue;
}
- return mFallback.generateLinks(request);
- }
-
- int getMaxGenerateLinksTextLength() {
- return mSettings.getGenerateLinksMaxTextLength();
- }
-
- private Collection<String> getEntitiesForHints(Collection<String> hints) {
- final boolean editable = hints.contains(TextClassifier.HINT_TEXT_IS_EDITABLE);
- final boolean notEditable = hints.contains(TextClassifier.HINT_TEXT_IS_NOT_EDITABLE);
-
- // Use the default if there is no hint, or conflicting ones.
- final boolean useDefault = editable == notEditable;
- if (useDefault) {
- return mSettings.getEntityListDefault();
- } else if (editable) {
- return mSettings.getEntityListEditable();
- } else { // notEditable
- return mSettings.getEntityListNotEditable();
- }
- }
-
- void onSelectionEvent(SelectionEvent event) {
- TextClassifierEvent textClassifierEvent =
- SelectionEventConverter.toTextClassifierEvent(event);
- if (textClassifierEvent == null) {
- return;
- }
- onTextClassifierEvent(event.getSessionId(), textClassifierEvent);
- }
-
- void onTextClassifierEvent(
- @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) {
- mTextClassifierEventLogger.writeEvent(sessionId, event);
- if (mSettings.isUserLanguageProfileEnabled()) {
- mLanguageProfileAnalyzer.onTextClassifierEven(event);
- }
- }
-
- TextLanguage detectLanguage(@NonNull TextLanguage.Request request) {
- Preconditions.checkNotNull(request);
- checkMainThread();
- try {
- final TextLanguage.Builder builder = new TextLanguage.Builder();
- final LangIdModel.LanguageResult[] langResults =
- getLangIdImpl().detectLanguages(request.getText().toString());
- for (int i = 0; i < langResults.length; i++) {
- builder.putLocale(
- ULocale.forLanguageTag(langResults[i].getLanguage()),
- langResults[i].getScore());
- }
- return builder.build();
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error detecting text language.", t);
- }
- return mFallback.detectLanguage(request);
- }
-
- ConversationActions suggestConversationActions(ConversationActions.Request request) {
- Preconditions.checkNotNull(request);
- checkMainThread();
- if (mSettings.isUserLanguageProfileEnabled()) {
- // FIXME: Reuse the LangID result.
- ListenableFuture<Void> ignoredResult =
- mLanguageProfileUpdater.updateFromConversationActionsAsync(
- request,
- text -> detectLanguages(text, getLangIdThreshold()).getEntities());
- }
- try {
- ActionsSuggestionsModel actionsImpl = getActionsImpl();
- if (actionsImpl == null) {
- // Actions model is optional, fallback if it is not available.
- return mFallback.suggestConversationActions(request);
- }
- ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
- ActionsSuggestionsHelper.toNativeMessages(
- request.getConversation(),
- text -> detectLanguages(text, getLangIdThreshold()).getEntities());
- if (nativeMessages.length == 0) {
- return mFallback.suggestConversationActions(request);
- }
- ActionsSuggestionsModel.Conversation nativeConversation =
- new ActionsSuggestionsModel.Conversation(nativeMessages);
-
- ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
- actionsImpl.suggestActionsWithIntents(
- nativeConversation,
- null,
- mContext,
- getResourceLocalesString(),
- getAnnotatorImpl(LocaleList.getDefault()));
- return createConversationActionResult(request, nativeSuggestions);
- } catch (Throwable t) {
- // Avoid throwing from this method. Log the error.
- TcLog.e(TAG, "Error suggesting conversation actions.", t);
- }
- return mFallback.suggestConversationActions(request);
- }
-
- /**
- * Returns the {@link ConversationAction} result, with a non-null extras.
- *
- * <p>Whenever the RemoteAction is non-null, you can expect its corresponding intent with a
- * non-null component name is in the extras.
- */
- private ConversationActions createConversationActionResult(
- ConversationActions.Request request,
- ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions) {
- Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
- List<ConversationAction> conversationActions = new ArrayList<>();
- for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : nativeSuggestions) {
- String actionType = nativeSuggestion.getActionType();
- if (!expectedTypes.contains(actionType)) {
- continue;
- }
- LabeledIntent.Result labeledIntentResult =
- ActionsSuggestionsHelper.createLabeledIntentResult(
- mContext, mTemplateIntentFactory, nativeSuggestion);
- RemoteAction remoteAction = null;
- Bundle extras = new Bundle();
- if (labeledIntentResult != null) {
- remoteAction = labeledIntentResult.remoteAction;
- ExtrasUtils.putActionIntent(extras, labeledIntentResult.resolvedIntent);
- }
- ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData());
- ExtrasUtils.putEntitiesExtras(
- extras,
- TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData()));
- conversationActions.add(
- new ConversationAction.Builder(actionType)
- .setConfidenceScore(nativeSuggestion.getScore())
- .setTextReply(nativeSuggestion.getResponseText())
- .setAction(remoteAction)
- .setExtras(extras)
- .build());
- }
- conversationActions =
- ActionsSuggestionsHelper.removeActionsWithDuplicates(conversationActions);
- if (request.getMaxSuggestions() >= 0
- && conversationActions.size() > request.getMaxSuggestions()) {
- conversationActions = conversationActions.subList(0, request.getMaxSuggestions());
- }
- synchronized (mLock) {
- String resultId =
- ActionsSuggestionsHelper.createResultId(
- mContext,
- request.getConversation(),
- mActionModelInUse.getVersion(),
- mActionModelInUse.getSupportedLocales());
- return new ConversationActions(conversationActions, resultId);
- }
- }
-
- private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) {
- List<String> defaultActionTypes =
- request.getHints().contains(ConversationActions.Request.HINT_FOR_NOTIFICATION)
- ? mSettings.getNotificationConversationActionTypes()
- : mSettings.getInAppConversationActionTypes();
- return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
- }
-
- private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException {
- synchronized (mLock) {
- localeList = localeList == null ? LocaleList.getDefault() : localeList;
- final ModelFileManager.ModelFile bestModel =
- mAnnotatorModelFileManager.findBestModelFile(localeList);
- if (bestModel == null) {
- throw new FileNotFoundException(
- "No annotator model for " + localeList.toLanguageTags());
- }
- if (mAnnotatorImpl == null || !Objects.equals(mAnnotatorModelInUse, bestModel)) {
- TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- try {
- if (pfd != null) {
- // The current annotator model may be still used by another thread / model.
- // Do not call close() here, and let the GC to clean it up when no one else
- // is using it.
- mAnnotatorImpl = new AnnotatorModel(pfd.getFd());
- mAnnotatorModelInUse = bestModel;
- }
- } finally {
- maybeCloseAndLogError(pfd);
- }
- }
- return mAnnotatorImpl;
- }
- }
-
- private LangIdModel getLangIdImpl() throws FileNotFoundException {
- synchronized (mLock) {
- final ModelFileManager.ModelFile bestModel =
- mLangIdModelFileManager.findBestModelFile(null);
- if (bestModel == null) {
- throw new FileNotFoundException("No LangID model is found");
- }
- if (mLangIdImpl == null || !Objects.equals(mLangIdModelInUse, bestModel)) {
- TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- try {
- if (pfd != null) {
- mLangIdImpl = new LangIdModel(pfd.getFd());
- mLangIdModelInUse = bestModel;
- }
- } finally {
- maybeCloseAndLogError(pfd);
- }
- }
- return mLangIdImpl;
- }
- }
-
- @Nullable
- private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException {
- synchronized (mLock) {
- // TODO: Use LangID to determine the locale we should use here?
- final ModelFileManager.ModelFile bestModel =
- mActionsModelFileManager.findBestModelFile(LocaleList.getDefault());
- if (bestModel == null) {
- return null;
- }
- if (mActionsImpl == null || !Objects.equals(mActionModelInUse, bestModel)) {
- TcLog.d(TAG, "Loading " + bestModel);
- final ParcelFileDescriptor pfd =
- ParcelFileDescriptor.open(
- new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
- try {
- if (pfd == null) {
- TcLog.d(TAG, "Failed to read the model file: " + bestModel.getPath());
- return null;
- }
- ActionsModelParams params = mActionsModelParamsSupplier.get();
- mActionsImpl =
- new ActionsSuggestionsModel(
- pfd.getFd(), params.getSerializedPreconditions(bestModel));
- mActionModelInUse = bestModel;
- } finally {
- maybeCloseAndLogError(pfd);
- }
- }
- return mActionsImpl;
- }
- }
-
- private String createId(String text, int start, int end) {
- synchronized (mLock) {
- return ResultIdUtils.createId(
- mContext,
- text,
- start,
- end,
- mAnnotatorModelInUse.getVersion(),
- mAnnotatorModelInUse.getSupportedLocales());
- }
- }
-
- private static String concatenateLocales(@Nullable LocaleList locales) {
- return (locales == null) ? "" : locales.toLanguageTags();
- }
-
- private TextClassification createClassificationResult(
- AnnotatorModel.ClassificationResult[] classifications,
- String text,
- int start,
- int end,
- @Nullable Instant referenceTime) {
- final String classifiedText = text.substring(start, end);
- final TextClassification.Builder builder =
- new TextClassification.Builder().setText(classifiedText);
-
- final int typeCount = classifications.length;
- AnnotatorModel.ClassificationResult highestScoringResult =
- typeCount > 0 ? classifications[0] : null;
- for (int i = 0; i < typeCount; i++) {
- builder.setEntityType(
- classifications[i].getCollection(), classifications[i].getScore());
- if (classifications[i].getScore() > highestScoringResult.getScore()) {
- highestScoringResult = classifications[i];
- }
- }
- final Pair<Bundle, Bundle> languagesBundles = generateLanguageBundles(text, start, end);
- final Bundle textLanguagesBundle = languagesBundles.first;
- final Bundle foreignLanguageBundle = languagesBundles.second;
-
- boolean isPrimaryAction = true;
- final List<LabeledIntent> labeledIntents =
- mClassificationIntentFactory.create(
- mContext,
- classifiedText,
- foreignLanguageBundle != null,
- referenceTime,
- highestScoringResult);
- final LabeledIntent.TitleChooser titleChooser =
- (labeledIntent, resolveInfo) -> labeledIntent.titleWithoutEntity;
-
- ArrayList<Intent> actionIntents = new ArrayList<>();
- for (LabeledIntent labeledIntent : labeledIntents) {
- final LabeledIntent.Result result =
- labeledIntent.resolve(mContext, titleChooser, textLanguagesBundle);
- if (result == null) {
- continue;
- }
-
- final Intent intent = result.resolvedIntent;
- final RemoteAction action = result.remoteAction;
- if (isPrimaryAction) {
- // For O backwards compatibility, the first RemoteAction is also written to the
- // legacy API fields.
- builder.setIcon(action.getIcon().loadDrawable(mContext));
- builder.setLabel(action.getTitle().toString());
- builder.setIntent(intent);
- builder.setOnClickListener(
- createIntentOnClickListener(
- createPendingIntent(mContext, intent, labeledIntent.requestCode)));
- isPrimaryAction = false;
- }
- builder.addAction(action);
- actionIntents.add(intent);
+ final Map<String, Float> entityScores = new ArrayMap<>();
+ for (int i = 0; i < results.length; i++) {
+ entityScores.put(results[i].getCollection(), results[i].getScore());
}
Bundle extras = new Bundle();
- ExtrasUtils.putForeignLanguageExtra(extras, foreignLanguageBundle);
- if (actionIntents.stream().anyMatch(Objects::nonNull)) {
- ExtrasUtils.putActionsIntents(extras, actionIntents);
+ if (isSerializedEntityDataEnabled) {
+ ExtrasUtils.putEntities(extras, results);
}
- ExtrasUtils.putEntities(extras, classifications);
- builder.setExtras(extras);
- return builder.setId(createId(text, start, end)).build();
+ builder.addLink(span.getStartIndex(), span.getEndIndex(), entityScores, extras);
+ }
+ final TextLinks links = builder.build();
+ final long endTimeMs = System.currentTimeMillis();
+ final String callingPackageName =
+ request.getCallingPackageName() == null
+ ? context.getPackageName() // local (in process) TC.
+ : request.getCallingPackageName();
+ generateLinksLogger.logGenerateLinks(
+ request.getText(), links, callingPackageName, endTimeMs - startTimeMs);
+ return links;
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error getting links info.", t);
}
+ return fallback.generateLinks(request);
+ }
- private static OnClickListener createIntentOnClickListener(
- @NonNull final PendingIntent intent) {
- Preconditions.checkNotNull(intent);
- return v -> {
- try {
- intent.send();
- } catch (PendingIntent.CanceledException e) {
- TcLog.e(TAG, "Error sending PendingIntent", e);
- }
- };
+ int getMaxGenerateLinksTextLength() {
+ return settings.getGenerateLinksMaxTextLength();
+ }
+
+ private Collection<String> getEntitiesForHints(Collection<String> hints) {
+ final boolean editable = hints.contains(TextClassifier.HINT_TEXT_IS_EDITABLE);
+ final boolean notEditable = hints.contains(TextClassifier.HINT_TEXT_IS_NOT_EDITABLE);
+
+ // Use the default if there is no hint, or conflicting ones.
+ final boolean useDefault = editable == notEditable;
+ if (useDefault) {
+ return settings.getEntityListDefault();
+ } else if (editable) {
+ return settings.getEntityListEditable();
+ } else { // notEditable
+ return settings.getEntityListNotEditable();
}
+ }
- /**
- * Returns a bundle pair with language detection information for extras.
- *
- * <p>Pair.first = textLanguagesBundle - A bundle containing information about all detected
- * languages in the text. May be null if language detection fails or is disabled. This is
- * typically expected to be added to a textClassifier generated remote action intent. See {@link
- * ExtrasUtils#putTextLanguagesExtra(Bundle, Bundle)}. See {@link
- * ExtrasUtils#getTopLanguage(Intent)}.
- *
- * <p>Pair.second = foreignLanguageBundle - A bundle with the language and confidence score if
- * the system finds the text to be in a foreign language. Otherwise is null. See {@link
- * TextClassification.Builder#setForeignLanguageExtra(Bundle)}.
- *
- * @param context the context of the text to detect languages for
- * @param start the start index of the text
- * @param end the end index of the text
- */
- // TODO: Revisit this algorithm.
- // TODO: Consider making this public API.
- private Pair<Bundle, Bundle> generateLanguageBundles(String context, int start, int end) {
- if (!mSettings.isTranslateInClassificationEnabled()) {
+ void onSelectionEvent(SelectionEvent event) {
+ TextClassifierEvent textClassifierEvent = SelectionEventConverter.toTextClassifierEvent(event);
+ if (textClassifierEvent == null) {
+ return;
+ }
+ onTextClassifierEvent(event.getSessionId(), textClassifierEvent);
+ }
+
+ void onTextClassifierEvent(
+ @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) {
+ textClassifierEventLogger.writeEvent(sessionId, event);
+ if (settings.isUserLanguageProfileEnabled()) {
+ languageProfileAnalyzer.onTextClassifierEven(event);
+ }
+ }
+
+ TextLanguage detectLanguage(TextLanguage.Request request) {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ try {
+ final TextLanguage.Builder builder = new TextLanguage.Builder();
+ final LangIdModel.LanguageResult[] langResults =
+ getLangIdImpl().detectLanguages(request.getText().toString());
+ for (int i = 0; i < langResults.length; i++) {
+ builder.putLocale(
+ ULocale.forLanguageTag(langResults[i].getLanguage()), langResults[i].getScore());
+ }
+ return builder.build();
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error detecting text language.", t);
+ }
+ return fallback.detectLanguage(request);
+ }
+
+ ConversationActions suggestConversationActions(ConversationActions.Request request) {
+ Preconditions.checkNotNull(request);
+ checkMainThread();
+ if (settings.isUserLanguageProfileEnabled()) {
+ // TODO(tonymak): Reuse the LangID result.
+ languageProfileUpdater.updateFromConversationActionsAsync(
+ request, text -> detectLanguages(text, getLangIdThreshold()).getEntities());
+ }
+ try {
+ ActionsSuggestionsModel actionsImpl = getActionsImpl();
+ if (actionsImpl == null) {
+ // Actions model is optional, fallback if it is not available.
+ return fallback.suggestConversationActions(request);
+ }
+ ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
+ ActionsSuggestionsHelper.toNativeMessages(
+ request.getConversation(),
+ text -> detectLanguages(text, getLangIdThreshold()).getEntities());
+ if (nativeMessages.length == 0) {
+ return fallback.suggestConversationActions(request);
+ }
+ ActionsSuggestionsModel.Conversation nativeConversation =
+ new ActionsSuggestionsModel.Conversation(nativeMessages);
+
+ ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
+ actionsImpl.suggestActionsWithIntents(
+ nativeConversation,
+ null,
+ context,
+ getResourceLocalesString(),
+ getAnnotatorImpl(LocaleList.getDefault()));
+ return createConversationActionResult(request, nativeSuggestions);
+ } catch (Throwable t) {
+ // Avoid throwing from this method. Log the error.
+ TcLog.e(TAG, "Error suggesting conversation actions.", t);
+ }
+ return fallback.suggestConversationActions(request);
+ }
+
+ /**
+ * Returns the {@link ConversationAction} result, with a non-null extras.
+ *
+ * <p>Whenever the RemoteAction is non-null, you can expect its corresponding intent with a
+ * non-null component name is in the extras.
+ */
+ private ConversationActions createConversationActionResult(
+ ConversationActions.Request request,
+ ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions) {
+ Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
+ List<ConversationAction> conversationActions = new ArrayList<>();
+ for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : nativeSuggestions) {
+ String actionType = nativeSuggestion.getActionType();
+ if (!expectedTypes.contains(actionType)) {
+ continue;
+ }
+ LabeledIntent.Result labeledIntentResult =
+ ActionsSuggestionsHelper.createLabeledIntentResult(
+ context, templateIntentFactory, nativeSuggestion);
+ RemoteAction remoteAction = null;
+ Bundle extras = new Bundle();
+ if (labeledIntentResult != null) {
+ remoteAction = labeledIntentResult.remoteAction;
+ ExtrasUtils.putActionIntent(extras, labeledIntentResult.resolvedIntent);
+ }
+ ExtrasUtils.putSerializedEntityData(extras, nativeSuggestion.getSerializedEntityData());
+ ExtrasUtils.putEntitiesExtras(
+ extras, TemplateIntentFactory.nameVariantsToBundle(nativeSuggestion.getEntityData()));
+ conversationActions.add(
+ new ConversationAction.Builder(actionType)
+ .setConfidenceScore(nativeSuggestion.getScore())
+ .setTextReply(nativeSuggestion.getResponseText())
+ .setAction(remoteAction)
+ .setExtras(extras)
+ .build());
+ }
+ conversationActions = ActionsSuggestionsHelper.removeActionsWithDuplicates(conversationActions);
+ if (request.getMaxSuggestions() >= 0
+ && conversationActions.size() > request.getMaxSuggestions()) {
+ conversationActions = conversationActions.subList(0, request.getMaxSuggestions());
+ }
+ synchronized (lock) {
+ String resultId =
+ ActionsSuggestionsHelper.createResultId(
+ context,
+ request.getConversation(),
+ actionModelInUse.getVersion(),
+ actionModelInUse.getSupportedLocales());
+ return new ConversationActions(conversationActions, resultId);
+ }
+ }
+
+ private Collection<String> resolveActionTypesFromRequest(ConversationActions.Request request) {
+ List<String> defaultActionTypes =
+ request.getHints().contains(ConversationActions.Request.HINT_FOR_NOTIFICATION)
+ ? settings.getNotificationConversationActionTypes()
+ : settings.getInAppConversationActionTypes();
+ return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
+ }
+
+  private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws FileNotFoundException {
+    synchronized (lock) { // annotatorImpl / annotatorModelInUse are read and swapped under lock.
+      localeList = localeList == null ? LocaleList.getDefault() : localeList; // null = default.
+      final ModelFileManager.ModelFile bestModel =
+          annotatorModelFileManager.findBestModelFile(localeList);
+      if (bestModel == null) {
+        throw new FileNotFoundException("No annotator model for " + localeList.toLanguageTags());
+      }
+      if (annotatorImpl == null || !Objects.equals(annotatorModelInUse, bestModel)) {
+        TcLog.d(TAG, "Loading " + bestModel); // (Re)load only when the best model file changed.
+        final ParcelFileDescriptor pfd =
+            ParcelFileDescriptor.open(
+                new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
+        try {
+          if (pfd != null) {
+            // The previous annotator model may still be in use by another thread / model.
+            // Do not close() it here; let the GC clean it up once no one else is using it.
+            // Only the pfd is closed, in the finally block below.
+            annotatorImpl = new AnnotatorModel(pfd.getFd());
+            annotatorModelInUse = bestModel;
+          }
+        } finally {
+          maybeCloseAndLogError(pfd);
+        }
+      }
+      return annotatorImpl; // Cached instance, reused until a different best model is found.
+    }
+  }
+
+ private LangIdModel getLangIdImpl() throws FileNotFoundException {
+ synchronized (lock) {
+ final ModelFileManager.ModelFile bestModel = langIdModelFileManager.findBestModelFile(null);
+ if (bestModel == null) {
+ throw new FileNotFoundException("No LangID model is found");
+ }
+ if (langIdImpl == null || !Objects.equals(langIdModelInUse, bestModel)) {
+ TcLog.d(TAG, "Loading " + bestModel);
+ final ParcelFileDescriptor pfd =
+ ParcelFileDescriptor.open(
+ new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
+ try {
+ if (pfd != null) {
+ langIdImpl = new LangIdModel(pfd.getFd());
+ langIdModelInUse = bestModel;
+ }
+ } finally {
+ maybeCloseAndLogError(pfd);
+ }
+ }
+ return langIdImpl;
+ }
+ }
+
+ @Nullable
+ private ActionsSuggestionsModel getActionsImpl() throws FileNotFoundException {
+ synchronized (lock) {
+ // TODO: Use LangID to determine the locale we should use here?
+ final ModelFileManager.ModelFile bestModel =
+ actionsModelFileManager.findBestModelFile(LocaleList.getDefault());
+ if (bestModel == null) {
+ return null;
+ }
+ if (actionsImpl == null || !Objects.equals(actionModelInUse, bestModel)) {
+ TcLog.d(TAG, "Loading " + bestModel);
+ final ParcelFileDescriptor pfd =
+ ParcelFileDescriptor.open(
+ new File(bestModel.getPath()), ParcelFileDescriptor.MODE_READ_ONLY);
+ try {
+ if (pfd == null) {
+ TcLog.d(TAG, "Failed to read the model file: " + bestModel.getPath());
return null;
+ }
+ ActionsModelParams params = actionsModelParamsSupplier.get();
+ actionsImpl =
+ new ActionsSuggestionsModel(
+ pfd.getFd(), params.getSerializedPreconditions(bestModel));
+ actionModelInUse = bestModel;
+ } finally {
+ maybeCloseAndLogError(pfd);
}
- try {
- final float threshold = getLangIdThreshold();
- if (threshold < 0 || threshold > 1) {
- TcLog.w(TAG, "[detectForeignLanguage] unexpected threshold is found: " + threshold);
- return Pair.create(null, null);
- }
+ }
+ return actionsImpl;
+ }
+ }
- final EntityConfidence languageScores = detectLanguages(context, start, end);
- if (languageScores.getEntities().isEmpty()) {
- return Pair.create(null, null);
- }
+ private String createId(String text, int start, int end) {
+ synchronized (lock) {
+ return ResultIdUtils.createId(
+ context,
+ text,
+ start,
+ end,
+ annotatorModelInUse.getVersion(),
+ annotatorModelInUse.getSupportedLocales());
+ }
+ }
- final Bundle textLanguagesBundle = new Bundle();
- ExtrasUtils.putTopLanguageScores(textLanguagesBundle, languageScores);
+ private static String concatenateLocales(@Nullable LocaleList locales) {
+ return (locales == null) ? "" : locales.toLanguageTags();
+ }
- final String language = languageScores.getEntities().get(0);
- final float score = languageScores.getConfidenceScore(language);
- if (score < threshold) {
- return Pair.create(textLanguagesBundle, null);
- }
+ private TextClassification createClassificationResult(
+ AnnotatorModel.ClassificationResult[] classifications,
+ String text,
+ int start,
+ int end,
+ @Nullable Instant referenceTime) {
+ final String classifiedText = text.substring(start, end);
+ final TextClassification.Builder builder =
+ new TextClassification.Builder().setText(classifiedText);
- TcLog.v(TAG, String.format(Locale.US, "Language detected: <%s:%.2f>", language, score));
- if (mSettings.isUserLanguageProfileEnabled()) {
- if (!mLanguageProfileAnalyzer.shouldShowTranslation(language)) {
- return Pair.create(textLanguagesBundle, null);
- }
- } else {
- final Locale detected = new Locale(language);
- final LocaleList deviceLocales = LocaleList.getDefault();
- final int size = deviceLocales.size();
- for (int i = 0; i < size; i++) {
- if (deviceLocales.get(i).getLanguage().equals(detected.getLanguage())) {
- return Pair.create(textLanguagesBundle, null);
- }
- }
- }
- final Bundle foreignLanguageBundle =
- ExtrasUtils.createForeignLanguageExtra(
- language, score, getLangIdImpl().getVersion());
- return Pair.create(textLanguagesBundle, foreignLanguageBundle);
- } catch (Throwable t) {
- TcLog.e(TAG, "Error generating language bundles.", t);
- }
+ final int typeCount = classifications.length;
+ AnnotatorModel.ClassificationResult highestScoringResult =
+ typeCount > 0 ? classifications[0] : null;
+ for (int i = 0; i < typeCount; i++) {
+ builder.setEntityType(classifications[i].getCollection(), classifications[i].getScore());
+ if (classifications[i].getScore() > highestScoringResult.getScore()) {
+ highestScoringResult = classifications[i];
+ }
+ }
+ final Pair<Bundle, Bundle> languagesBundles = generateLanguageBundles(text, start, end);
+ final Bundle textLanguagesBundle = languagesBundles.first;
+ final Bundle foreignLanguageBundle = languagesBundles.second;
+
+ boolean isPrimaryAction = true;
+ final List<LabeledIntent> labeledIntents =
+ classificationIntentFactory.create(
+ context,
+ classifiedText,
+ foreignLanguageBundle != null,
+ referenceTime,
+ highestScoringResult);
+ final LabeledIntent.TitleChooser titleChooser =
+ (labeledIntent, resolveInfo) -> labeledIntent.titleWithoutEntity;
+
+ ArrayList<Intent> actionIntents = new ArrayList<>();
+ for (LabeledIntent labeledIntent : labeledIntents) {
+ final LabeledIntent.Result result =
+ labeledIntent.resolve(context, titleChooser, textLanguagesBundle);
+ if (result == null) {
+ continue;
+ }
+
+ final Intent intent = result.resolvedIntent;
+ final RemoteAction action = result.remoteAction;
+ if (isPrimaryAction) {
+ // For O backwards compatibility, the first RemoteAction is also written to the
+ // legacy API fields.
+ builder.setIcon(action.getIcon().loadDrawable(context));
+ builder.setLabel(action.getTitle().toString());
+ builder.setIntent(intent);
+ builder.setOnClickListener(
+ createIntentOnClickListener(
+ createPendingIntent(context, intent, labeledIntent.requestCode)));
+ isPrimaryAction = false;
+ }
+ builder.addAction(action);
+ actionIntents.add(intent);
+ }
+ Bundle extras = new Bundle();
+ ExtrasUtils.putForeignLanguageExtra(extras, foreignLanguageBundle);
+ if (actionIntents.stream().anyMatch(Objects::nonNull)) {
+ ExtrasUtils.putActionsIntents(extras, actionIntents);
+ }
+ ExtrasUtils.putEntities(extras, classifications);
+ builder.setExtras(extras);
+ return builder.setId(createId(text, start, end)).build();
+ }
+
+ private static OnClickListener createIntentOnClickListener(final PendingIntent intent) {
+ Preconditions.checkNotNull(intent);
+ return v -> {
+ try {
+ intent.send();
+ } catch (PendingIntent.CanceledException e) {
+ TcLog.e(TAG, "Error sending PendingIntent", e);
+ }
+ };
+ }
+
+  /**
+   * Returns a pair of language-detection bundles for extras; the pair itself is never null.
+   *
+   * <p>Pair.first = textLanguagesBundle - A bundle containing information about all detected
+   * languages in the text. May be null if language detection fails or is disabled. This is
+   * typically expected to be added to a textClassifier generated remote action intent. See {@link
+   * ExtrasUtils#putTextLanguagesExtra(Bundle, Bundle)}. See {@link
+   * ExtrasUtils#getTopLanguage(Intent)}.
+   *
+   * <p>Pair.second = foreignLanguageBundle - A bundle with the language and confidence score if the
+   * system finds the text to be in a foreign language. Otherwise null. See {@link
+   * TextClassification.Builder#setForeignLanguageExtra(Bundle)}.
+   *
+   * @param context the full text; surrounding text may be used when the range is short
+   * @param start the start index of the text
+   * @param end the end index (exclusive) of the text
+   */
+  // TODO: Revisit this algorithm.
+  // TODO: Consider making this public API.
+  private Pair<Bundle, Bundle> generateLanguageBundles(String context, int start, int end) {
+    if (!settings.isTranslateInClassificationEnabled()) {
+      return Pair.create(null, null); // Never null: callers read .first/.second directly.
+    }
+    try {
+      final float threshold = getLangIdThreshold();
+      if (threshold < 0 || threshold > 1) {
+        TcLog.w(TAG, "[detectForeignLanguage] unexpected threshold is found: " + threshold);
+        return Pair.create(null, null);
+      }
+
+      final EntityConfidence languageScores = detectLanguages(context, start, end);
+      if (languageScores.getEntities().isEmpty()) {
+        return Pair.create(null, null);
+      }
+
+      final Bundle textLanguagesBundle = new Bundle();
+      ExtrasUtils.putTopLanguageScores(textLanguagesBundle, languageScores);
+
+      final String language = languageScores.getEntities().get(0);
+      final float score = languageScores.getConfidenceScore(language);
+      if (score < threshold) { // Not confident enough to treat the text as foreign.
+        return Pair.create(textLanguagesBundle, null);
+      }
+
+      TcLog.v(TAG, String.format(Locale.US, "Language detected: <%s:%.2f>", language, score));
+      if (settings.isUserLanguageProfileEnabled()) {
+        if (!languageProfileAnalyzer.shouldShowTranslation(language)) {
+          return Pair.create(textLanguagesBundle, null);
+        }
+      } else {
+        final Locale detected = new Locale(language); // Compared by language code only.
+        final LocaleList deviceLocales = LocaleList.getDefault();
+        final int size = deviceLocales.size();
+        for (int i = 0; i < size; i++) {
+          if (deviceLocales.get(i).getLanguage().equals(detected.getLanguage())) {
+            return Pair.create(textLanguagesBundle, null);
+          }
+        }
+      }
+      final Bundle foreignLanguageBundle =
+          ExtrasUtils.createForeignLanguageExtra(language, score, getLangIdImpl().getVersion());
+      return Pair.create(textLanguagesBundle, foreignLanguageBundle);
+    } catch (Throwable t) {
+      TcLog.e(TAG, "Error generating language bundles.", t);
+    }
+    return Pair.create(null, null);
+  }
+
+ /**
+ * Detect the language of a piece of text by taking surrounding text into consideration.
+ *
+ * @param text text providing context for the text for which its language is to be detected
+ * @param start the start index of the text to detect its language
+ * @param end the end index of the text to detect its language
+ */
+ // TODO: Revisit this algorithm.
+ private EntityConfidence detectLanguages(String text, int start, int end) {
+ Preconditions.checkArgument(start >= 0);
+ Preconditions.checkArgument(end <= text.length());
+ Preconditions.checkArgument(start <= end);
+
+ final float[] langIdContextSettings = settings.getLangIdContextSettings();
+ // The minimum size of text to prefer for detection.
+ final int minimumTextSize = (int) langIdContextSettings[0];
+ // For reducing the score when text is less than the preferred size.
+ final float penalizeRatio = langIdContextSettings[1];
+ // Original detection score to surrounding text detection score ratios.
+ final float subjectTextScoreRatio = langIdContextSettings[2];
+ final float moreTextScoreRatio = 1f - subjectTextScoreRatio;
+ TcLog.v(
+ TAG,
+ String.format(
+ Locale.US,
+ "LangIdContextSettings: "
+ + "minimumTextSize=%d, penalizeRatio=%.2f, "
+ + "subjectTextScoreRatio=%.2f, moreTextScoreRatio=%.2f",
+ minimumTextSize,
+ penalizeRatio,
+ subjectTextScoreRatio,
+ moreTextScoreRatio));
+
+ if (end - start < minimumTextSize && penalizeRatio <= 0) {
+ return EntityConfidence.EMPTY;
}
- /**
- * Detect the language of a piece of text by taking surrounding text into consideration.
- *
- * @param text text providing context for the text for which its language is to be detected
- * @param start the start index of the text to detect its language
- * @param end the end index of the text to detect its language
- */
- // TODO: Revisit this algorithm.
- private EntityConfidence detectLanguages(String text, int start, int end) {
- Preconditions.checkArgument(start >= 0);
- Preconditions.checkArgument(end <= text.length());
- Preconditions.checkArgument(start <= end);
+ final String subject = text.substring(start, end);
+ final EntityConfidence scores = detectLanguages(subject, /* threshold= */ 0f);
- final float[] langIdContextSettings = mSettings.getLangIdContextSettings();
- // The minimum size of text to prefer for detection.
- final int minimumTextSize = (int) langIdContextSettings[0];
- // For reducing the score when text is less than the preferred size.
- final float penalizeRatio = langIdContextSettings[1];
- // Original detection score to surrounding text detection score ratios.
- final float subjectTextScoreRatio = langIdContextSettings[2];
- final float moreTextScoreRatio = 1f - subjectTextScoreRatio;
- TcLog.v(
- TAG,
- String.format(
- Locale.US,
- "LangIdContextSettings: "
- + "minimumTextSize=%d, penalizeRatio=%.2f, "
- + "subjectTextScoreRatio=%.2f, moreTextScoreRatio=%.2f",
- minimumTextSize,
- penalizeRatio,
- subjectTextScoreRatio,
- moreTextScoreRatio));
-
- if (end - start < minimumTextSize && penalizeRatio <= 0) {
- return EntityConfidence.EMPTY;
- }
-
- final String subject = text.substring(start, end);
- final EntityConfidence scores = detectLanguages(subject, /* threshold= */ 0f);
-
- if (subject.length() >= minimumTextSize
- || subject.length() == text.length()
- || subjectTextScoreRatio * penalizeRatio >= 1) {
- return scores;
- }
-
- final EntityConfidence moreTextScores;
- if (moreTextScoreRatio >= 0) {
- // Attempt to grow the detection text to be at least minimumTextSize long.
- final String moreText = StringUtils.getSubString(text, start, end, minimumTextSize);
- moreTextScores = detectLanguages(moreText, /* threshold= */ 0f);
- } else {
- moreTextScores = EntityConfidence.EMPTY;
- }
-
- // Combine the original detection scores with the those returned after including more text.
- final Map<String, Float> newScores = new ArrayMap<>();
- final Set<String> languages = new ArraySet<>();
- languages.addAll(scores.getEntities());
- languages.addAll(moreTextScores.getEntities());
- for (String language : languages) {
- final float score =
- (subjectTextScoreRatio * scores.getConfidenceScore(language)
- + moreTextScoreRatio
- * moreTextScores.getConfidenceScore(language))
- * penalizeRatio;
- newScores.put(language, score);
- }
- return new EntityConfidence(newScores);
+ if (subject.length() >= minimumTextSize
+ || subject.length() == text.length()
+ || subjectTextScoreRatio * penalizeRatio >= 1) {
+ return scores;
}
- /**
- * Detects languages for the specified text. Only returns languages with score that is higher
- * than or equal to the specified threshold.
- */
- private EntityConfidence detectLanguages(CharSequence text, float threshold) {
- final LangIdModel langId;
- try {
- langId = getLangIdImpl();
- } catch (FileNotFoundException e) {
- TcLog.e(TAG, "detectLanguages: Failed to call getLangIdImpl ", e);
- return EntityConfidence.EMPTY;
- }
- final LangIdModel.LanguageResult[] langResults = langId.detectLanguages(text.toString());
- final Map<String, Float> languagesMap = new ArrayMap<>();
- for (LangIdModel.LanguageResult langResult : langResults) {
- if (langResult.getScore() >= threshold) {
- languagesMap.put(langResult.getLanguage(), langResult.getScore());
- }
- }
- return new EntityConfidence(languagesMap);
+ final EntityConfidence moreTextScores;
+ if (moreTextScoreRatio >= 0) {
+ // Attempt to grow the detection text to be at least minimumTextSize long.
+ final String moreText = StringUtils.getSubString(text, start, end, minimumTextSize);
+ moreTextScores = detectLanguages(moreText, /* threshold= */ 0f);
+ } else {
+ moreTextScores = EntityConfidence.EMPTY;
}
- private float getLangIdThreshold() {
- try {
- return mSettings.getLangIdThresholdOverride() >= 0
- ? mSettings.getLangIdThresholdOverride()
- : getLangIdImpl().getLangIdThreshold();
- } catch (FileNotFoundException e) {
- final float defaultThreshold = 0.5f;
- TcLog.v(TAG, "Using default foreign language threshold: " + defaultThreshold);
- return defaultThreshold;
- }
+ // Combine the original detection scores with the those returned after including more text.
+ final Map<String, Float> newScores = new ArrayMap<>();
+ final Set<String> languages = new ArraySet<>();
+ languages.addAll(scores.getEntities());
+ languages.addAll(moreTextScores.getEntities());
+ for (String language : languages) {
+ final float score =
+ (subjectTextScoreRatio * scores.getConfidenceScore(language)
+ + moreTextScoreRatio * moreTextScores.getConfidenceScore(language))
+ * penalizeRatio;
+ newScores.put(language, score);
+ }
+ return new EntityConfidence(newScores);
+ }
+
+ /**
+ * Detects languages for the specified text. Only returns languages with score that is higher than
+ * or equal to the specified threshold.
+ */
+ private EntityConfidence detectLanguages(CharSequence text, float threshold) {
+ final LangIdModel langId;
+ try {
+ langId = getLangIdImpl();
+ } catch (FileNotFoundException e) {
+ TcLog.e(TAG, "detectLanguages: Failed to call getLangIdImpl ", e);
+ return EntityConfidence.EMPTY;
+ }
+ final LangIdModel.LanguageResult[] langResults = langId.detectLanguages(text.toString());
+ final Map<String, Float> languagesMap = new ArrayMap<>();
+ for (LangIdModel.LanguageResult langResult : langResults) {
+ if (langResult.getScore() >= threshold) {
+ languagesMap.put(langResult.getLanguage(), langResult.getScore());
+ }
+ }
+ return new EntityConfidence(languagesMap);
+ }
+
+ private float getLangIdThreshold() {
+ try {
+ return settings.getLangIdThresholdOverride() >= 0
+ ? settings.getLangIdThresholdOverride()
+ : getLangIdImpl().getLangIdThreshold();
+ } catch (FileNotFoundException e) {
+ final float defaultThreshold = 0.5f;
+ TcLog.v(TAG, "Using default foreign language threshold: " + defaultThreshold);
+ return defaultThreshold;
+ }
+ }
+
+ void dump(IndentingPrintWriter printWriter) {
+ synchronized (lock) {
+ printWriter.println("TextClassifierImpl:");
+ printWriter.increaseIndent();
+ printWriter.println("Annotator model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFileManager.ModelFile modelFile : annotatorModelFileManager.listModelFiles()) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.println("LangID model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFileManager.ModelFile modelFile : langIdModelFileManager.listModelFiles()) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.println("Actions model file(s):");
+ printWriter.increaseIndent();
+ for (ModelFileManager.ModelFile modelFile : actionsModelFileManager.listModelFiles()) {
+ printWriter.println(modelFile.toString());
+ }
+ printWriter.decreaseIndent();
+ printWriter.printPair("mFallback", fallback);
+ printWriter.decreaseIndent();
+ printWriter.println();
+ settings.dump(printWriter);
+ if (settings.isUserLanguageProfileEnabled()) {
+ printWriter.println();
+ languageProfileUpdater.dump(printWriter);
+ printWriter.println();
+ languageProfileAnalyzer.dump(printWriter);
+ }
+ }
+ }
+
+ /** Returns the locales string for the current resources configuration. */
+ private String getResourceLocalesString() {
+ try {
+ return context.getResources().getConfiguration().getLocales().toLanguageTags();
+ } catch (NullPointerException e) {
+
+ // NPE is unexpected. Erring on the side of caution.
+ return LocaleList.getDefault().toLanguageTags();
+ }
+ }
+
+ /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
+ private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
+ if (fd == null) {
+ return;
}
- void dump(IndentingPrintWriter printWriter) {
- synchronized (mLock) {
- printWriter.println("TextClassifierImpl:");
- printWriter.increaseIndent();
- printWriter.println("Annotator model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile :
- mAnnotatorModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
- printWriter.println("LangID model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile : mLangIdModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
- printWriter.println("Actions model file(s):");
- printWriter.increaseIndent();
- for (ModelFileManager.ModelFile modelFile : mActionsModelFileManager.listModelFiles()) {
- printWriter.println(modelFile.toString());
- }
- printWriter.decreaseIndent();
- printWriter.printPair("mFallback", mFallback);
- printWriter.decreaseIndent();
- printWriter.println();
- mSettings.dump(printWriter);
- if (mSettings.isUserLanguageProfileEnabled()) {
- printWriter.println();
- mLanguageProfileUpdater.dump(printWriter);
- printWriter.println();
- mLanguageProfileAnalyzer.dump(printWriter);
- }
- }
+ try {
+ fd.close();
+ } catch (IOException e) {
+ TcLog.e(TAG, "Error closing file.", e);
}
+ }
- /** Returns the locales string for the current resources configuration. */
- private String getResourceLocalesString() {
- try {
- return mContext.getResources().getConfiguration().getLocales().toLanguageTags();
- } catch (NullPointerException e) {
-
- // NPE is unexpected. Erring on the side of caution.
- return LocaleList.getDefault().toLanguageTags();
- }
+ private static void checkMainThread() {
+ if (Looper.myLooper() == Looper.getMainLooper()) {
+ TcLog.e(TAG, "TextClassifier called on main thread", new Exception());
}
+ }
- /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
- private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
- if (fd == null) {
- return;
- }
-
- try {
- fd.close();
- } catch (IOException e) {
- TcLog.e(TAG, "Error closing file.", e);
- }
- }
-
- private static void checkMainThread() {
- if (Looper.myLooper() == Looper.getMainLooper()) {
- TcLog.e(TAG, "TextClassifier called on main thread", new Exception());
- }
- }
-
- private static PendingIntent createPendingIntent(
- @NonNull final Context context, @NonNull final Intent intent, int requestCode) {
- return PendingIntent.getActivity(
- context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
- }
+ private static PendingIntent createPendingIntent(
+ final Context context, final Intent intent, int requestCode) {
+ return PendingIntent.getActivity(
+ context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ }
}
diff --git a/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java b/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java
index 180bc96..8704644 100644
--- a/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java
+++ b/java/src/com/android/textclassifier/intent/ClassificationIntentFactory.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,44 +13,41 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.intent;
import android.content.Context;
import android.content.Intent;
-
-import androidx.annotation.Nullable;
-
import com.android.textclassifier.R;
-
import com.google.android.textclassifier.AnnotatorModel;
-
import java.time.Instant;
import java.util.List;
+import javax.annotation.Nullable;
-/** @hide */
+/** Generates intents from classification results. */
public interface ClassificationIntentFactory {
- /** Return a list of LabeledIntent from the classification result. */
- List<LabeledIntent> create(
- Context context,
- String text,
- boolean foreignText,
- @Nullable Instant referenceTime,
- @Nullable AnnotatorModel.ClassificationResult classification);
+ /** Return a list of LabeledIntent from the classification result. */
+ List<LabeledIntent> create(
+ Context context,
+ String text,
+ boolean foreignText,
+ @Nullable Instant referenceTime,
+ @Nullable AnnotatorModel.ClassificationResult classification);
- /** Inserts translate action to the list if it is a foreign text. */
- static void insertTranslateAction(List<LabeledIntent> actions, Context context, String text) {
- actions.add(
- new LabeledIntent(
- context.getString(R.string.translate),
- /* titleWithEntity */ null,
- context.getString(R.string.translate_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_TRANSLATE)
- // TODO: Probably better to introduce a "translate" scheme instead
- // of
- // using EXTRA_TEXT.
- .putExtra(Intent.EXTRA_TEXT, text),
- text.hashCode()));
- }
+ /** Inserts translate action to the list if it is a foreign text. */
+ static void insertTranslateAction(List<LabeledIntent> actions, Context context, String text) {
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.translate),
+ /* titleWithEntity */ null,
+ context.getString(R.string.translate_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_TRANSLATE)
+ // TODO: Probably better to introduce a "translate" scheme instead
+ // of
+ // using EXTRA_TEXT.
+ .putExtra(Intent.EXTRA_TEXT, text),
+ text.hashCode()));
+ }
}
diff --git a/java/src/com/android/textclassifier/intent/LabeledIntent.java b/java/src/com/android/textclassifier/intent/LabeledIntent.java
index e51aa82..e678291 100644
--- a/java/src/com/android/textclassifier/intent/LabeledIntent.java
+++ b/java/src/com/android/textclassifier/intent/LabeledIntent.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.intent;
import android.app.PendingIntent;
@@ -26,195 +27,183 @@
import android.os.Bundle;
import android.text.TextUtils;
import android.view.textclassifier.TextClassifier;
-
-import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-import androidx.core.util.Preconditions;
-
import com.android.textclassifier.ExtrasUtils;
import com.android.textclassifier.R;
import com.android.textclassifier.TcLog;
+import com.google.common.base.Preconditions;
+import javax.annotation.Nullable;
/**
* Helper class to store the information from which RemoteActions are built.
*
* @hide
*/
-@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public final class LabeledIntent {
- private static final String TAG = "LabeledIntent";
- public static final int DEFAULT_REQUEST_CODE = 0;
- private static final TitleChooser DEFAULT_TITLE_CHOOSER =
- (labeledIntent, resolveInfo) -> {
- if (!TextUtils.isEmpty(labeledIntent.titleWithEntity)) {
- return labeledIntent.titleWithEntity;
- }
- return labeledIntent.titleWithoutEntity;
- };
-
- @Nullable public final String titleWithoutEntity;
- @Nullable public final String titleWithEntity;
- public final String description;
- @Nullable public final String descriptionWithAppName;
- // Do not update this intent.
- public final Intent intent;
- public final int requestCode;
-
- /**
- * Initializes a LabeledIntent.
- *
- * <p>NOTE: {@code requestCode} is required to not be {@link #DEFAULT_REQUEST_CODE} if
- * distinguishing info (e.g. the classified text) is represented in intent extras only. In such
- * circumstances, the request code should represent the distinguishing info (e.g. by generating
- * a hashcode) so that the generated PendingIntent is (somewhat) unique. To be correct, the
- * PendingIntent should be definitely unique but we try a best effort approach that avoids
- * spamming the system with PendingIntents.
- */
- // TODO: Fix the issue mentioned above so the behaviour is correct.
- public LabeledIntent(
- @Nullable String titleWithoutEntity,
- @Nullable String titleWithEntity,
- String description,
- @Nullable String descriptionWithAppName,
- Intent intent,
- int requestCode) {
- if (TextUtils.isEmpty(titleWithEntity) && TextUtils.isEmpty(titleWithoutEntity)) {
- throw new IllegalArgumentException(
- "titleWithEntity and titleWithoutEntity should not be both null");
+ private static final String TAG = "LabeledIntent";
+ public static final int DEFAULT_REQUEST_CODE = 0;
+ private static final TitleChooser DEFAULT_TITLE_CHOOSER =
+ (labeledIntent, resolveInfo) -> {
+ if (!TextUtils.isEmpty(labeledIntent.titleWithEntity)) {
+ return labeledIntent.titleWithEntity;
}
- this.titleWithoutEntity = titleWithoutEntity;
- this.titleWithEntity = titleWithEntity;
- this.description = Preconditions.checkNotNull(description);
- this.descriptionWithAppName = descriptionWithAppName;
- this.intent = Preconditions.checkNotNull(intent);
- this.requestCode = requestCode;
- }
+ return labeledIntent.titleWithoutEntity;
+ };
+ @Nullable public final String titleWithoutEntity;
+ @Nullable public final String titleWithEntity;
+ public final String description;
+ @Nullable public final String descriptionWithAppName;
+ // Do not update this intent.
+ public final Intent intent;
+ public final int requestCode;
+
+ /**
+ * Initializes a LabeledIntent.
+ *
+ * <p>NOTE: {@code requestCode} is required to not be {@link #DEFAULT_REQUEST_CODE} if
+ * distinguishing info (e.g. the classified text) is represented in intent extras only. In such
+ * circumstances, the request code should represent the distinguishing info (e.g. by generating a
+ * hashcode) so that the generated PendingIntent is (somewhat) unique. To be correct, the
+ * PendingIntent should be definitely unique but we try a best effort approach that avoids
+ * spamming the system with PendingIntents.
+ */
+ // TODO: Fix the issue mentioned above so the behaviour is correct.
+ public LabeledIntent(
+ @Nullable String titleWithoutEntity,
+ @Nullable String titleWithEntity,
+ String description,
+ @Nullable String descriptionWithAppName,
+ Intent intent,
+ int requestCode) {
+ if (TextUtils.isEmpty(titleWithEntity) && TextUtils.isEmpty(titleWithoutEntity)) {
+ throw new IllegalArgumentException(
+ "titleWithEntity and titleWithoutEntity should not be both null");
+ }
+ this.titleWithoutEntity = titleWithoutEntity;
+ this.titleWithEntity = titleWithEntity;
+ this.description = Preconditions.checkNotNull(description);
+ this.descriptionWithAppName = descriptionWithAppName;
+ this.intent = Preconditions.checkNotNull(intent);
+ this.requestCode = requestCode;
+ }
+
+ /**
+ * Return the resolved result.
+ *
+ * @param context the context to resolve the result's intent and action
+ * @param titleChooser for choosing an action title
+ * @param textLanguagesBundle containing language detection information
+ */
+ @Nullable
+ public Result resolve(
+ Context context, @Nullable TitleChooser titleChooser, @Nullable Bundle textLanguagesBundle) {
+ final PackageManager pm = context.getPackageManager();
+ final ResolveInfo resolveInfo = pm.resolveActivity(intent, 0);
+
+ if (resolveInfo == null || resolveInfo.activityInfo == null) {
+ TcLog.w(TAG, "resolveInfo or activityInfo is null");
+ return null;
+ }
+ final String packageName = resolveInfo.activityInfo.packageName;
+ final String className = resolveInfo.activityInfo.name;
+ if (packageName == null || className == null) {
+ TcLog.w(TAG, "packageName or className is null");
+ return null;
+ }
+ Intent resolvedIntent = new Intent(intent);
+ resolvedIntent.putExtra(
+ TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER, getFromTextClassifierExtra(textLanguagesBundle));
+ boolean shouldShowIcon = false;
+ Icon icon = null;
+ if (!"android".equals(packageName)) {
+ // We only set the component name when the package name is not resolved to "android"
+ // to workaround a bug that explicit intent with component name == ResolverActivity
+ // can't be launched on keyguard.
+ resolvedIntent.setComponent(new ComponentName(packageName, className));
+ if (resolveInfo.activityInfo.getIconResource() != 0) {
+ icon = Icon.createWithResource(packageName, resolveInfo.activityInfo.getIconResource());
+ shouldShowIcon = true;
+ }
+ }
+ if (icon == null) {
+ // RemoteAction requires that there be an icon.
+ icon = Icon.createWithResource(context, R.drawable.app_icon);
+ }
+ final PendingIntent pendingIntent = createPendingIntent(context, resolvedIntent, requestCode);
+ titleChooser = titleChooser == null ? DEFAULT_TITLE_CHOOSER : titleChooser;
+ CharSequence title = titleChooser.chooseTitle(this, resolveInfo);
+ if (TextUtils.isEmpty(title)) {
+ TcLog.w(TAG, "Custom titleChooser return null, fallback to the default titleChooser");
+ title = DEFAULT_TITLE_CHOOSER.chooseTitle(this, resolveInfo);
+ }
+ final RemoteAction action =
+ new RemoteAction(icon, title, resolveDescription(resolveInfo, pm), pendingIntent);
+ action.setShouldShowIcon(shouldShowIcon);
+ return new Result(resolvedIntent, action);
+ }
+
+ private String resolveDescription(ResolveInfo resolveInfo, PackageManager packageManager) {
+ if (!TextUtils.isEmpty(descriptionWithAppName)) {
+ // Example string format of descriptionWithAppName: "Use %1$s to open map".
+ String applicationName = getApplicationName(resolveInfo, packageManager);
+ if (!TextUtils.isEmpty(applicationName)) {
+ return String.format(descriptionWithAppName, applicationName);
+ }
+ }
+ return description;
+ }
+
+ @Nullable
+ private String getApplicationName(ResolveInfo resolveInfo, PackageManager packageManager) {
+ if (resolveInfo.activityInfo == null) {
+ return null;
+ }
+ if ("android".equals(resolveInfo.activityInfo.packageName)) {
+ return null;
+ }
+ if (resolveInfo.activityInfo.applicationInfo == null) {
+ return null;
+ }
+ return (String) packageManager.getApplicationLabel(resolveInfo.activityInfo.applicationInfo);
+ }
+
+ private Bundle getFromTextClassifierExtra(@Nullable Bundle textLanguagesBundle) {
+ if (textLanguagesBundle != null) {
+ final Bundle bundle = new Bundle();
+ ExtrasUtils.putTextLanguagesExtra(bundle, textLanguagesBundle);
+ return bundle;
+ } else {
+ return Bundle.EMPTY;
+ }
+ }
+
+ private static PendingIntent createPendingIntent(
+ final Context context, final Intent intent, int requestCode) {
+ return PendingIntent.getActivity(
+ context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
+ }
+
+ /** Data class that holds the result. */
+ public static final class Result {
+ public final Intent resolvedIntent;
+ public final RemoteAction remoteAction;
+
+ public Result(Intent resolvedIntent, RemoteAction remoteAction) {
+ this.resolvedIntent = Preconditions.checkNotNull(resolvedIntent);
+ this.remoteAction = Preconditions.checkNotNull(remoteAction);
+ }
+ }
+
+ /**
+ * An object to choose a title from resolved info. If {@code null} is returned, {@link
+ * #titleWithEntity} will be used if it exists, {@link #titleWithoutEntity} otherwise.
+ */
+ public interface TitleChooser {
/**
- * Return the resolved result.
- *
- * @param context the context to resolve the result's intent and action
- * @param titleChooser for choosing an action title
- * @param textLanguagesBundle containing language detection information
+ * Picks a title from a {@link LabeledIntent} by looking into resolved info. {@code resolveInfo}
+ * is guaranteed to have a non-null {@code activityInfo}.
*/
@Nullable
- public Result resolve(
- Context context,
- @Nullable TitleChooser titleChooser,
- @Nullable Bundle textLanguagesBundle) {
- final PackageManager pm = context.getPackageManager();
- final ResolveInfo resolveInfo = pm.resolveActivity(intent, 0);
-
- if (resolveInfo == null || resolveInfo.activityInfo == null) {
- TcLog.w(TAG, "resolveInfo or activityInfo is null");
- return null;
- }
- final String packageName = resolveInfo.activityInfo.packageName;
- final String className = resolveInfo.activityInfo.name;
- if (packageName == null || className == null) {
- TcLog.w(TAG, "packageName or className is null");
- return null;
- }
- Intent resolvedIntent = new Intent(intent);
- resolvedIntent.putExtra(
- TextClassifier.EXTRA_FROM_TEXT_CLASSIFIER,
- getFromTextClassifierExtra(textLanguagesBundle));
- boolean shouldShowIcon = false;
- Icon icon = null;
- if (!"android".equals(packageName)) {
- // We only set the component name when the package name is not resolved to "android"
- // to workaround a bug that explicit intent with component name == ResolverActivity
- // can't be launched on keyguard.
- resolvedIntent.setComponent(new ComponentName(packageName, className));
- if (resolveInfo.activityInfo.getIconResource() != 0) {
- icon =
- Icon.createWithResource(
- packageName, resolveInfo.activityInfo.getIconResource());
- shouldShowIcon = true;
- }
- }
- if (icon == null) {
- // RemoteAction requires that there be an icon.
- icon = Icon.createWithResource(context, R.drawable.app_icon);
- }
- final PendingIntent pendingIntent =
- createPendingIntent(context, resolvedIntent, requestCode);
- titleChooser = titleChooser == null ? DEFAULT_TITLE_CHOOSER : titleChooser;
- CharSequence title = titleChooser.chooseTitle(this, resolveInfo);
- if (TextUtils.isEmpty(title)) {
- TcLog.w(TAG, "Custom titleChooser return null, fallback to the default titleChooser");
- title = DEFAULT_TITLE_CHOOSER.chooseTitle(this, resolveInfo);
- }
- final RemoteAction action =
- new RemoteAction(icon, title, resolveDescription(resolveInfo, pm), pendingIntent);
- action.setShouldShowIcon(shouldShowIcon);
- return new Result(resolvedIntent, action);
- }
-
- private String resolveDescription(ResolveInfo resolveInfo, PackageManager packageManager) {
- if (!TextUtils.isEmpty(descriptionWithAppName)) {
- // Example string format of descriptionWithAppName: "Use %1$s to open map".
- String applicationName = getApplicationName(resolveInfo, packageManager);
- if (!TextUtils.isEmpty(applicationName)) {
- return String.format(descriptionWithAppName, applicationName);
- }
- }
- return description;
- }
-
- @Nullable
- private String getApplicationName(ResolveInfo resolveInfo, PackageManager packageManager) {
- if (resolveInfo.activityInfo == null) {
- return null;
- }
- if ("android".equals(resolveInfo.activityInfo.packageName)) {
- return null;
- }
- if (resolveInfo.activityInfo.applicationInfo == null) {
- return null;
- }
- return (String)
- packageManager.getApplicationLabel(resolveInfo.activityInfo.applicationInfo);
- }
-
- private Bundle getFromTextClassifierExtra(@Nullable Bundle textLanguagesBundle) {
- if (textLanguagesBundle != null) {
- final Bundle bundle = new Bundle();
- ExtrasUtils.putTextLanguagesExtra(bundle, textLanguagesBundle);
- return bundle;
- } else {
- return Bundle.EMPTY;
- }
- }
-
- private static PendingIntent createPendingIntent(
- @NonNull final Context context, @NonNull final Intent intent, int requestCode) {
- return PendingIntent.getActivity(
- context, requestCode, intent, PendingIntent.FLAG_UPDATE_CURRENT);
- }
-
- /** Data class that holds the result. */
- public static final class Result {
- public final Intent resolvedIntent;
- public final RemoteAction remoteAction;
-
- public Result(Intent resolvedIntent, RemoteAction remoteAction) {
- this.resolvedIntent = Preconditions.checkNotNull(resolvedIntent);
- this.remoteAction = Preconditions.checkNotNull(remoteAction);
- }
- }
-
- /**
- * An object to choose a title from resolved info. If {@code null} is returned, {@link
- * #titleWithEntity} will be used if it exists, {@link #titleWithoutEntity} otherwise.
- */
- public interface TitleChooser {
- /**
- * Picks a title from a {@link LabeledIntent} by looking into resolved info. {@code
- * resolveInfo} is guaranteed to have a non-null {@code activityInfo}.
- */
- @Nullable
- CharSequence chooseTitle(LabeledIntent labeledIntent, ResolveInfo resolveInfo);
- }
+ CharSequence chooseTitle(LabeledIntent labeledIntent, ResolveInfo resolveInfo);
+ }
}
diff --git a/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java b/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java
index 4bac39a..c58df76 100644
--- a/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java
+++ b/java/src/com/android/textclassifier/intent/LegacyClassificationIntentFactory.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.intent;
import static java.time.temporal.ChronoUnit.MILLIS;
@@ -28,15 +29,9 @@
import android.provider.CalendarContract;
import android.provider.ContactsContract;
import android.view.textclassifier.TextClassifier;
-
-import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
-
import com.android.textclassifier.R;
import com.android.textclassifier.TcLog;
-
import com.google.android.textclassifier.AnnotatorModel;
-
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.time.Instant;
@@ -44,256 +39,239 @@
import java.util.List;
import java.util.Locale;
import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
-/**
- * Creates intents based on the classification type.
- *
- * @hide
- */
+/** Creates intents based on the classification type. */
// TODO: Consider to support {@code descriptionWithAppName}.
public final class LegacyClassificationIntentFactory implements ClassificationIntentFactory {
- private static final String TAG = "LegacyClassificationIntentFactory";
- private static final long MIN_EVENT_FUTURE_MILLIS = TimeUnit.MINUTES.toMillis(5);
- private static final long DEFAULT_EVENT_DURATION = TimeUnit.HOURS.toMillis(1);
+ private static final String TAG = "LegacyIntentFactory";
+ private static final long MIN_EVENT_FUTURE_MILLIS = TimeUnit.MINUTES.toMillis(5);
+ private static final long DEFAULT_EVENT_DURATION = TimeUnit.HOURS.toMillis(1);
- // Sync with TextClassifier.TYPE_DICTIONARY.
- private static final String TYPE_DICTIONARY = "dictionary";
+ // Sync with TextClassifier.TYPE_DICTIONARY.
+ private static final String TYPE_DICTIONARY = "dictionary";
- @NonNull
- @Override
- public List<LabeledIntent> create(
- Context context,
- String text,
- boolean foreignText,
- @Nullable Instant referenceTime,
- AnnotatorModel.ClassificationResult classification) {
- final String type =
- classification != null
- ? classification.getCollection().trim().toLowerCase(Locale.ENGLISH)
- : "";
- text = text.trim();
- final List<LabeledIntent> actions;
- switch (type) {
- case TextClassifier.TYPE_EMAIL:
- actions = createForEmail(context, text);
- break;
- case TextClassifier.TYPE_PHONE:
- actions = createForPhone(context, text);
- break;
- case TextClassifier.TYPE_ADDRESS:
- actions = createForAddress(context, text);
- break;
- case TextClassifier.TYPE_URL:
- actions = createForUrl(context, text);
- break;
- case TextClassifier.TYPE_DATE: // fall through
- case TextClassifier.TYPE_DATE_TIME:
- if (classification.getDatetimeResult() != null) {
- final Instant parsedTime =
- Instant.ofEpochMilli(classification.getDatetimeResult().getTimeMsUtc());
- actions = createForDatetime(context, type, referenceTime, parsedTime);
- } else {
- actions = new ArrayList<>();
- }
- break;
- case TextClassifier.TYPE_FLIGHT_NUMBER:
- actions = createForFlight(context, text);
- break;
- // case TextClassifier.TYPE_DICTIONARY:
- case TYPE_DICTIONARY:
- actions = createForDictionary(context, text);
- break;
- default:
- actions = new ArrayList<>();
- break;
+ @Override
+ public List<LabeledIntent> create(
+ Context context,
+ String text,
+ boolean foreignText,
+ @Nullable Instant referenceTime,
+ AnnotatorModel.ClassificationResult classification) {
+ final String type =
+ classification != null
+ ? classification.getCollection().trim().toLowerCase(Locale.ENGLISH)
+ : "";
+ text = text.trim();
+ final List<LabeledIntent> actions;
+ switch (type) {
+ case TextClassifier.TYPE_EMAIL:
+ actions = createForEmail(context, text);
+ break;
+ case TextClassifier.TYPE_PHONE:
+ actions = createForPhone(context, text);
+ break;
+ case TextClassifier.TYPE_ADDRESS:
+ actions = createForAddress(context, text);
+ break;
+ case TextClassifier.TYPE_URL:
+ actions = createForUrl(context, text);
+ break;
+ case TextClassifier.TYPE_DATE: // fall through
+ case TextClassifier.TYPE_DATE_TIME:
+ if (classification.getDatetimeResult() != null) {
+ final Instant parsedTime =
+ Instant.ofEpochMilli(classification.getDatetimeResult().getTimeMsUtc());
+ actions = createForDatetime(context, type, referenceTime, parsedTime);
+ } else {
+ actions = new ArrayList<>();
}
- if (foreignText) {
- ClassificationIntentFactory.insertTranslateAction(actions, context, text);
- }
- return actions;
+ break;
+ case TextClassifier.TYPE_FLIGHT_NUMBER:
+ actions = createForFlight(context, text);
+ break;
+ // case TextClassifier.TYPE_DICTIONARY:
+ case TYPE_DICTIONARY:
+ actions = createForDictionary(context, text);
+ break;
+ default:
+ actions = new ArrayList<>();
+ break;
}
+ if (foreignText) {
+ ClassificationIntentFactory.insertTranslateAction(actions, context, text);
+ }
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForEmail(Context context, String text) {
- final List<LabeledIntent> actions = new ArrayList<>();
- actions.add(
- new LabeledIntent(
- context.getString(R.string.email),
- /* titleWithEntity */ null,
- context.getString(R.string.email_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_SENDTO)
- .setData(Uri.parse(String.format("mailto:%s", text))),
- LabeledIntent.DEFAULT_REQUEST_CODE));
- actions.add(
- new LabeledIntent(
- context.getString(R.string.add_contact),
- /* titleWithEntity */ null,
- context.getString(R.string.add_contact_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_INSERT_OR_EDIT)
- .setType(ContactsContract.Contacts.CONTENT_ITEM_TYPE)
- .putExtra(ContactsContract.Intents.Insert.EMAIL, text),
- text.hashCode()));
- return actions;
- }
+ private static List<LabeledIntent> createForEmail(Context context, String text) {
+ final List<LabeledIntent> actions = new ArrayList<>();
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.email),
+ /* titleWithEntity */ null,
+ context.getString(R.string.email_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_SENDTO).setData(Uri.parse(String.format("mailto:%s", text))),
+ LabeledIntent.DEFAULT_REQUEST_CODE));
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.add_contact),
+ /* titleWithEntity */ null,
+ context.getString(R.string.add_contact_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_INSERT_OR_EDIT)
+ .setType(ContactsContract.Contacts.CONTENT_ITEM_TYPE)
+ .putExtra(ContactsContract.Intents.Insert.EMAIL, text),
+ text.hashCode()));
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForPhone(Context context, String text) {
- final List<LabeledIntent> actions = new ArrayList<>();
- final UserManager userManager = context.getSystemService(UserManager.class);
- final Bundle userRestrictions =
- userManager != null ? userManager.getUserRestrictions() : new Bundle();
- if (!userRestrictions.getBoolean(UserManager.DISALLOW_OUTGOING_CALLS, false)) {
- actions.add(
- new LabeledIntent(
- context.getString(R.string.dial),
- /* titleWithEntity */ null,
- context.getString(R.string.dial_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_DIAL)
- .setData(Uri.parse(String.format("tel:%s", text))),
- LabeledIntent.DEFAULT_REQUEST_CODE));
- }
- actions.add(
- new LabeledIntent(
- context.getString(R.string.add_contact),
- /* titleWithEntity */ null,
- context.getString(R.string.add_contact_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_INSERT_OR_EDIT)
- .setType(ContactsContract.Contacts.CONTENT_ITEM_TYPE)
- .putExtra(ContactsContract.Intents.Insert.PHONE, text),
- text.hashCode()));
- if (!userRestrictions.getBoolean(UserManager.DISALLOW_SMS, false)) {
- actions.add(
- new LabeledIntent(
- context.getString(R.string.sms),
- /* titleWithEntity */ null,
- context.getString(R.string.sms_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_SENDTO)
- .setData(Uri.parse(String.format("smsto:%s", text))),
- LabeledIntent.DEFAULT_REQUEST_CODE));
- }
- return actions;
+ private static List<LabeledIntent> createForPhone(Context context, String text) {
+ final List<LabeledIntent> actions = new ArrayList<>();
+ final UserManager userManager = context.getSystemService(UserManager.class);
+ final Bundle userRestrictions =
+ userManager != null ? userManager.getUserRestrictions() : new Bundle();
+ if (!userRestrictions.getBoolean(UserManager.DISALLOW_OUTGOING_CALLS, false)) {
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.dial),
+ /* titleWithEntity */ null,
+ context.getString(R.string.dial_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_DIAL).setData(Uri.parse(String.format("tel:%s", text))),
+ LabeledIntent.DEFAULT_REQUEST_CODE));
}
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.add_contact),
+ /* titleWithEntity */ null,
+ context.getString(R.string.add_contact_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_INSERT_OR_EDIT)
+ .setType(ContactsContract.Contacts.CONTENT_ITEM_TYPE)
+ .putExtra(ContactsContract.Intents.Insert.PHONE, text),
+ text.hashCode()));
+ if (!userRestrictions.getBoolean(UserManager.DISALLOW_SMS, false)) {
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.sms),
+ /* titleWithEntity */ null,
+ context.getString(R.string.sms_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_SENDTO).setData(Uri.parse(String.format("smsto:%s", text))),
+ LabeledIntent.DEFAULT_REQUEST_CODE));
+ }
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForAddress(Context context, String text) {
- final List<LabeledIntent> actions = new ArrayList<>();
- try {
- final String encText = URLEncoder.encode(text, "UTF-8");
- actions.add(
- new LabeledIntent(
- context.getString(R.string.map),
- /* titleWithEntity */ null,
- context.getString(R.string.map_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_VIEW)
- .setData(Uri.parse(String.format("geo:0,0?q=%s", encText))),
- LabeledIntent.DEFAULT_REQUEST_CODE));
- } catch (UnsupportedEncodingException e) {
- TcLog.e(TAG, "Could not encode address", e);
- }
- return actions;
+ private static List<LabeledIntent> createForAddress(Context context, String text) {
+ final List<LabeledIntent> actions = new ArrayList<>();
+ try {
+ final String encText = URLEncoder.encode(text, "UTF-8");
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.map),
+ /* titleWithEntity */ null,
+ context.getString(R.string.map_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_VIEW)
+ .setData(Uri.parse(String.format("geo:0,0?q=%s", encText))),
+ LabeledIntent.DEFAULT_REQUEST_CODE));
+ } catch (UnsupportedEncodingException e) {
+ TcLog.e(TAG, "Could not encode address", e);
}
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForUrl(Context context, String text) {
- if (Uri.parse(text).getScheme() == null) {
- text = "http://" + text;
- }
- final List<LabeledIntent> actions = new ArrayList<>();
- actions.add(
- new LabeledIntent(
- context.getString(R.string.browse),
- /* titleWithEntity */ null,
- context.getString(R.string.browse_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_VIEW)
- .setDataAndNormalize(Uri.parse(text))
- .putExtra(Browser.EXTRA_APPLICATION_ID, context.getPackageName()),
- LabeledIntent.DEFAULT_REQUEST_CODE));
- return actions;
+ private static List<LabeledIntent> createForUrl(Context context, String text) {
+ if (Uri.parse(text).getScheme() == null) {
+ text = "http://" + text;
}
+ final List<LabeledIntent> actions = new ArrayList<>();
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.browse),
+ /* titleWithEntity */ null,
+ context.getString(R.string.browse_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_VIEW)
+ .setDataAndNormalize(Uri.parse(text))
+ .putExtra(Browser.EXTRA_APPLICATION_ID, context.getPackageName()),
+ LabeledIntent.DEFAULT_REQUEST_CODE));
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForDatetime(
- Context context, String type, @Nullable Instant referenceTime, Instant parsedTime) {
- if (referenceTime == null) {
- // If no reference time was given, use now.
- referenceTime = Instant.now();
- }
- List<LabeledIntent> actions = new ArrayList<>();
- actions.add(createCalendarViewIntent(context, parsedTime));
- final long millisUntilEvent = referenceTime.until(parsedTime, MILLIS);
- if (millisUntilEvent > MIN_EVENT_FUTURE_MILLIS) {
- actions.add(createCalendarCreateEventIntent(context, parsedTime, type));
- }
- return actions;
+ private static List<LabeledIntent> createForDatetime(
+ Context context, String type, @Nullable Instant referenceTime, Instant parsedTime) {
+ if (referenceTime == null) {
+ // If no reference time was given, use now.
+ referenceTime = Instant.now();
}
+ List<LabeledIntent> actions = new ArrayList<>();
+ actions.add(createCalendarViewIntent(context, parsedTime));
+ final long millisUntilEvent = referenceTime.until(parsedTime, MILLIS);
+ if (millisUntilEvent > MIN_EVENT_FUTURE_MILLIS) {
+ actions.add(createCalendarCreateEventIntent(context, parsedTime, type));
+ }
+ return actions;
+ }
- @NonNull
- private static List<LabeledIntent> createForFlight(Context context, String text) {
- final List<LabeledIntent> actions = new ArrayList<>();
- actions.add(
- new LabeledIntent(
- context.getString(R.string.view_flight),
- /* titleWithEntity */ null,
- context.getString(R.string.view_flight_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_WEB_SEARCH).putExtra(SearchManager.QUERY, text),
- text.hashCode()));
- return actions;
- }
+ private static List<LabeledIntent> createForFlight(Context context, String text) {
+ final List<LabeledIntent> actions = new ArrayList<>();
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.view_flight),
+ /* titleWithEntity */ null,
+ context.getString(R.string.view_flight_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_WEB_SEARCH).putExtra(SearchManager.QUERY, text),
+ text.hashCode()));
+ return actions;
+ }
- @NonNull
- private static LabeledIntent createCalendarViewIntent(Context context, Instant parsedTime) {
- Uri.Builder builder = CalendarContract.CONTENT_URI.buildUpon();
- builder.appendPath("time");
- ContentUris.appendId(builder, parsedTime.toEpochMilli());
- return new LabeledIntent(
- context.getString(R.string.view_calendar),
- /* titleWithEntity */ null,
- context.getString(R.string.view_calendar_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_VIEW).setData(builder.build()),
- LabeledIntent.DEFAULT_REQUEST_CODE);
- }
+ private static LabeledIntent createCalendarViewIntent(Context context, Instant parsedTime) {
+ Uri.Builder builder = CalendarContract.CONTENT_URI.buildUpon();
+ builder.appendPath("time");
+ ContentUris.appendId(builder, parsedTime.toEpochMilli());
+ return new LabeledIntent(
+ context.getString(R.string.view_calendar),
+ /* titleWithEntity */ null,
+ context.getString(R.string.view_calendar_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_VIEW).setData(builder.build()),
+ LabeledIntent.DEFAULT_REQUEST_CODE);
+ }
- @NonNull
- private static LabeledIntent createCalendarCreateEventIntent(
- Context context, Instant parsedTime, String type) {
- final boolean isAllDay = TextClassifier.TYPE_DATE.equals(type);
- return new LabeledIntent(
- context.getString(R.string.add_calendar_event),
- /* titleWithEntity */ null,
- context.getString(R.string.add_calendar_event_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_INSERT)
- .setData(CalendarContract.Events.CONTENT_URI)
- .putExtra(CalendarContract.EXTRA_EVENT_ALL_DAY, isAllDay)
- .putExtra(
- CalendarContract.EXTRA_EVENT_BEGIN_TIME, parsedTime.toEpochMilli())
- .putExtra(
- CalendarContract.EXTRA_EVENT_END_TIME,
- parsedTime.toEpochMilli() + DEFAULT_EVENT_DURATION),
- parsedTime.hashCode());
- }
+ private static LabeledIntent createCalendarCreateEventIntent(
+ Context context, Instant parsedTime, String type) {
+ final boolean isAllDay = TextClassifier.TYPE_DATE.equals(type);
+ return new LabeledIntent(
+ context.getString(R.string.add_calendar_event),
+ /* titleWithEntity */ null,
+ context.getString(R.string.add_calendar_event_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_INSERT)
+ .setData(CalendarContract.Events.CONTENT_URI)
+ .putExtra(CalendarContract.EXTRA_EVENT_ALL_DAY, isAllDay)
+ .putExtra(CalendarContract.EXTRA_EVENT_BEGIN_TIME, parsedTime.toEpochMilli())
+ .putExtra(
+ CalendarContract.EXTRA_EVENT_END_TIME,
+ parsedTime.toEpochMilli() + DEFAULT_EVENT_DURATION),
+ parsedTime.hashCode());
+ }
- @NonNull
- private static List<LabeledIntent> createForDictionary(Context context, String text) {
- final List<LabeledIntent> actions = new ArrayList<>();
- actions.add(
- new LabeledIntent(
- context.getString(R.string.define),
- /* titleWithEntity */ null,
- context.getString(R.string.define_desc),
- /* descriptionWithAppName */ null,
- new Intent(Intent.ACTION_DEFINE).putExtra(Intent.EXTRA_TEXT, text),
- text.hashCode()));
- return actions;
- }
+ private static List<LabeledIntent> createForDictionary(Context context, String text) {
+ final List<LabeledIntent> actions = new ArrayList<>();
+ actions.add(
+ new LabeledIntent(
+ context.getString(R.string.define),
+ /* titleWithEntity */ null,
+ context.getString(R.string.define_desc),
+ /* descriptionWithAppName */ null,
+ new Intent(Intent.ACTION_DEFINE).putExtra(Intent.EXTRA_TEXT, text),
+ text.hashCode()));
+ return actions;
+ }
}
diff --git a/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java b/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java
index 53f5bf4..f5ef577 100644
--- a/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java
+++ b/java/src/com/android/textclassifier/intent/TemplateClassificationIntentFactory.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,69 +13,60 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.intent;
import android.content.Context;
-
-import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-import androidx.core.util.Preconditions;
-
import com.android.textclassifier.TcLog;
-
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.RemoteActionTemplate;
-
+import com.google.common.base.Preconditions;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
+import javax.annotation.Nullable;
/**
* Creates intents based on {@link RemoteActionTemplate} objects for a ClassificationResult.
*
* @hide
*/
-@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public final class TemplateClassificationIntentFactory implements ClassificationIntentFactory {
- private static final String TAG = "TemplateClassificationIntentFactory";
- private final TemplateIntentFactory mTemplateIntentFactory;
- private final ClassificationIntentFactory mFallback;
+ private static final String TAG = "TemplateClassificationIntentFactory"; // must differ from TemplateIntentFactory.TAG so log lines are attributable
+ private final TemplateIntentFactory templateIntentFactory; // builds intents from RemoteActionTemplates
+ private final ClassificationIntentFactory fallback; // used when a result carries no RemoteActionTemplates
- public TemplateClassificationIntentFactory(
- TemplateIntentFactory templateIntentFactory, ClassificationIntentFactory fallback) {
- mTemplateIntentFactory = Preconditions.checkNotNull(templateIntentFactory);
- mFallback = Preconditions.checkNotNull(fallback);
- }
+ public TemplateClassificationIntentFactory(
+ TemplateIntentFactory templateIntentFactory, ClassificationIntentFactory fallback) { // both args must be non-null
+ this.templateIntentFactory = Preconditions.checkNotNull(templateIntentFactory);
+ this.fallback = Preconditions.checkNotNull(fallback);
+ }
- /**
- * Returns a list of {@link LabeledIntent} that are constructed from the classification result.
- */
- @NonNull
- @Override
- public List<LabeledIntent> create(
- Context context,
- String text,
- boolean foreignText,
- @Nullable Instant referenceTime,
- @Nullable AnnotatorModel.ClassificationResult classification) {
- if (classification == null) {
- return Collections.emptyList();
- }
- RemoteActionTemplate[] remoteActionTemplates = classification.getRemoteActionTemplates();
- if (remoteActionTemplates == null) {
- // RemoteActionTemplate is missing, fallback.
- TcLog.w(
- TAG,
- "RemoteActionTemplate is missing, fallback to"
- + " LegacyClassificationIntentFactory.");
- return mFallback.create(context, text, foreignText, referenceTime, classification);
- }
- final List<LabeledIntent> labeledIntents =
- mTemplateIntentFactory.create(remoteActionTemplates);
- if (foreignText) {
- ClassificationIntentFactory.insertTranslateAction(labeledIntents, context, text.trim());
- }
- return labeledIntents;
+ /**
+ * Returns a list of {@link LabeledIntent} that are constructed from the classification result.
+ */
+ @Override
+ public List<LabeledIntent> create( // empty list for a null result; delegates to fallback when templates are absent
+ Context context,
+ String text,
+ boolean foreignText,
+ @Nullable Instant referenceTime,
+ @Nullable AnnotatorModel.ClassificationResult classification) {
+ if (classification == null) {
+ return Collections.emptyList();
}
+ RemoteActionTemplate[] remoteActionTemplates = classification.getRemoteActionTemplates();
+ if (remoteActionTemplates == null) {
+ // RemoteActionTemplate is missing, fallback.
+ TcLog.w(
+ TAG,
+ "RemoteActionTemplate is missing, fallback to LegacyClassificationIntentFactory.");
+ return fallback.create(context, text, foreignText, referenceTime, classification);
+ }
+ final List<LabeledIntent> labeledIntents = templateIntentFactory.create(remoteActionTemplates);
+ if (foreignText) { // foreign-language text additionally gets a translate action
+ ClassificationIntentFactory.insertTranslateAction(labeledIntents, context, text.trim());
+ }
+ return labeledIntents;
+ }
}
diff --git a/java/src/com/android/textclassifier/intent/TemplateIntentFactory.java b/java/src/com/android/textclassifier/intent/TemplateIntentFactory.java
index 57f608a..718c988 100644
--- a/java/src/com/android/textclassifier/intent/TemplateIntentFactory.java
+++ b/java/src/com/android/textclassifier/intent/TemplateIntentFactory.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,144 +13,136 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.intent;
import android.content.Intent;
import android.net.Uri;
import android.os.Bundle;
import android.text.TextUtils;
-
-import androidx.annotation.NonNull;
-import androidx.annotation.Nullable;
-import androidx.annotation.VisibleForTesting;
-
import com.android.textclassifier.TcLog;
-
import com.google.android.textclassifier.NamedVariant;
import com.google.android.textclassifier.RemoteActionTemplate;
-
import java.util.ArrayList;
import java.util.List;
+import javax.annotation.Nullable;
/**
* Creates intents based on {@link RemoteActionTemplate} objects.
*
* @hide
*/
-@VisibleForTesting(otherwise = VisibleForTesting.PACKAGE_PRIVATE)
public final class TemplateIntentFactory {
- private static final String TAG = "TemplateIntentFactory";
+ private static final String TAG = "TemplateIntentFactory";
- /** Constructs and returns a list of {@link LabeledIntent} based on the given templates. */
- @Nullable
- public List<LabeledIntent> create(@NonNull RemoteActionTemplate[] remoteActionTemplates) {
- if (remoteActionTemplates.length == 0) {
- return new ArrayList<>();
- }
- final List<LabeledIntent> labeledIntents = new ArrayList<>();
- for (RemoteActionTemplate remoteActionTemplate : remoteActionTemplates) {
- if (!isValidTemplate(remoteActionTemplate)) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate skipped.");
- continue;
- }
- labeledIntents.add(
- new LabeledIntent(
- remoteActionTemplate.titleWithoutEntity,
- remoteActionTemplate.titleWithEntity,
- remoteActionTemplate.description,
- remoteActionTemplate.descriptionWithAppName,
- createIntent(remoteActionTemplate),
- remoteActionTemplate.requestCode == null
- ? LabeledIntent.DEFAULT_REQUEST_CODE
- : remoteActionTemplate.requestCode));
- }
- return labeledIntents;
+ /** Returns a new, mutable list of {@link LabeledIntent}s built from the valid templates; never null. */
+ // Invalid (or null) templates are logged and skipped instead of failing the whole batch.
+ public List<LabeledIntent> create(RemoteActionTemplate[] remoteActionTemplates) {
+ if (remoteActionTemplates.length == 0) {
+ return new ArrayList<>();
}
+ final List<LabeledIntent> labeledIntents = new ArrayList<>();
+ for (RemoteActionTemplate remoteActionTemplate : remoteActionTemplates) {
+ if (!isValidTemplate(remoteActionTemplate)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate skipped.");
+ continue;
+ }
+ labeledIntents.add(
+ new LabeledIntent(
+ remoteActionTemplate.titleWithoutEntity,
+ remoteActionTemplate.titleWithEntity,
+ remoteActionTemplate.description,
+ remoteActionTemplate.descriptionWithAppName,
+ createIntent(remoteActionTemplate),
+ remoteActionTemplate.requestCode == null
+ ? LabeledIntent.DEFAULT_REQUEST_CODE
+ : remoteActionTemplate.requestCode));
+ }
+ return labeledIntents;
+ }
- private static boolean isValidTemplate(@Nullable RemoteActionTemplate remoteActionTemplate) {
- if (remoteActionTemplate == null) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate: is null");
- return false;
- }
- if (TextUtils.isEmpty(remoteActionTemplate.titleWithEntity)
- && TextUtils.isEmpty(remoteActionTemplate.titleWithoutEntity)) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate: title is null");
- return false;
- }
- if (TextUtils.isEmpty(remoteActionTemplate.description)) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate: description is null");
- return false;
- }
- if (!TextUtils.isEmpty(remoteActionTemplate.packageName)) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate: package name is set");
- return false;
- }
- if (TextUtils.isEmpty(remoteActionTemplate.action)) {
- TcLog.w(TAG, "Invalid RemoteActionTemplate: intent action not set");
- return false;
- }
- return true;
+ private static boolean isValidTemplate(@Nullable RemoteActionTemplate remoteActionTemplate) {
+ if (remoteActionTemplate == null) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: is null");
+ return false;
}
+ if (TextUtils.isEmpty(remoteActionTemplate.titleWithEntity)
+ && TextUtils.isEmpty(remoteActionTemplate.titleWithoutEntity)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: title is null");
+ return false;
+ }
+ if (TextUtils.isEmpty(remoteActionTemplate.description)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: description is null");
+ return false;
+ }
+ if (!TextUtils.isEmpty(remoteActionTemplate.packageName)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: package name is set");
+ return false;
+ }
+ if (TextUtils.isEmpty(remoteActionTemplate.action)) {
+ TcLog.w(TAG, "Invalid RemoteActionTemplate: intent action not set");
+ return false;
+ }
+ return true;
+ }
- private static Intent createIntent(RemoteActionTemplate remoteActionTemplate) {
- final Intent intent = new Intent(remoteActionTemplate.action);
- final Uri uri =
- TextUtils.isEmpty(remoteActionTemplate.data)
- ? null
- : Uri.parse(remoteActionTemplate.data).normalizeScheme();
- final String type =
- TextUtils.isEmpty(remoteActionTemplate.type)
- ? null
- : Intent.normalizeMimeType(remoteActionTemplate.type);
- intent.setDataAndType(uri, type);
- intent.setFlags(remoteActionTemplate.flags == null ? 0 : remoteActionTemplate.flags);
- if (remoteActionTemplate.category != null) {
- for (String category : remoteActionTemplate.category) {
- if (category != null) {
- intent.addCategory(category);
- }
- }
+ private static Intent createIntent(RemoteActionTemplate remoteActionTemplate) {
+ final Intent intent = new Intent(remoteActionTemplate.action);
+ final Uri uri =
+ TextUtils.isEmpty(remoteActionTemplate.data)
+ ? null
+ : Uri.parse(remoteActionTemplate.data).normalizeScheme();
+ final String type =
+ TextUtils.isEmpty(remoteActionTemplate.type)
+ ? null
+ : Intent.normalizeMimeType(remoteActionTemplate.type);
+ intent.setDataAndType(uri, type);
+ intent.setFlags(remoteActionTemplate.flags == null ? 0 : remoteActionTemplate.flags);
+ if (remoteActionTemplate.category != null) {
+ for (String category : remoteActionTemplate.category) {
+ if (category != null) {
+ intent.addCategory(category);
}
- intent.putExtras(nameVariantsToBundle(remoteActionTemplate.extras));
- return intent;
+ }
}
+ intent.putExtras(nameVariantsToBundle(remoteActionTemplate.extras));
+ return intent;
+ }
- /** Converts an array of {@link NamedVariant} to a Bundle and returns it. */
- public static Bundle nameVariantsToBundle(@Nullable NamedVariant[] namedVariants) {
- if (namedVariants == null) {
- return Bundle.EMPTY;
- }
- Bundle bundle = new Bundle();
- for (NamedVariant namedVariant : namedVariants) {
- if (namedVariant == null) {
- continue;
- }
- switch (namedVariant.getType()) {
- case NamedVariant.TYPE_INT:
- bundle.putInt(namedVariant.getName(), namedVariant.getInt());
- break;
- case NamedVariant.TYPE_LONG:
- bundle.putLong(namedVariant.getName(), namedVariant.getLong());
- break;
- case NamedVariant.TYPE_FLOAT:
- bundle.putFloat(namedVariant.getName(), namedVariant.getFloat());
- break;
- case NamedVariant.TYPE_DOUBLE:
- bundle.putDouble(namedVariant.getName(), namedVariant.getDouble());
- break;
- case NamedVariant.TYPE_BOOL:
- bundle.putBoolean(namedVariant.getName(), namedVariant.getBool());
- break;
- case NamedVariant.TYPE_STRING:
- bundle.putString(namedVariant.getName(), namedVariant.getString());
- break;
- default:
- TcLog.w(
- TAG,
- "Unsupported type found in nameVariantsToBundle : "
- + namedVariant.getType());
- }
- }
- return bundle;
+ /** Converts an array of {@link NamedVariant} to a Bundle and returns it. */
+ public static Bundle nameVariantsToBundle(@Nullable NamedVariant[] namedVariants) {
+ if (namedVariants == null) {
+ return Bundle.EMPTY; // shared instance documented as unmodifiable: callers must not mutate it
}
+ Bundle bundle = new Bundle();
+ for (NamedVariant namedVariant : namedVariants) {
+ if (namedVariant == null) {
+ continue;
+ }
+ switch (namedVariant.getType()) {
+ case NamedVariant.TYPE_INT:
+ bundle.putInt(namedVariant.getName(), namedVariant.getInt());
+ break;
+ case NamedVariant.TYPE_LONG:
+ bundle.putLong(namedVariant.getName(), namedVariant.getLong());
+ break;
+ case NamedVariant.TYPE_FLOAT:
+ bundle.putFloat(namedVariant.getName(), namedVariant.getFloat());
+ break;
+ case NamedVariant.TYPE_DOUBLE:
+ bundle.putDouble(namedVariant.getName(), namedVariant.getDouble());
+ break;
+ case NamedVariant.TYPE_BOOL:
+ bundle.putBoolean(namedVariant.getName(), namedVariant.getBool());
+ break;
+ case NamedVariant.TYPE_STRING:
+ bundle.putString(namedVariant.getName(), namedVariant.getString());
+ break;
+ default:
+ TcLog.w(
+ TAG, "Unsupported type found in nameVariantsToBundle : " + namedVariant.getType());
+ }
+ }
+ return bundle;
+ }
}
diff --git a/java/src/com/android/textclassifier/logging/GenerateLinksLogger.java b/java/src/com/android/textclassifier/logging/GenerateLinksLogger.java
index d06104c..5da8e93 100644
--- a/java/src/com/android/textclassifier/logging/GenerateLinksLogger.java
+++ b/java/src/com/android/textclassifier/logging/GenerateLinksLogger.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2017 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,128 +19,126 @@
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextClassifierEvent;
import android.view.textclassifier.TextLinks;
-
-import androidx.annotation.Nullable;
import androidx.collection.ArrayMap;
-import androidx.core.util.Preconditions;
-
import com.android.textclassifier.TcLog;
import com.android.textclassifier.TextClassifierStatsLog;
-
+import com.google.common.base.Preconditions;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.UUID;
+import javax.annotation.Nullable;
/** A helper for logging calls to generateLinks. */
public final class GenerateLinksLogger {
- private static final String LOG_TAG = "GenerateLinksLogger";
+ private static final String LOG_TAG = "GenerateLinksLogger";
- private final Random mRng;
- private final int mSampleRate;
+ private final Random rng; // source of randomness for sampling in shouldLog()
+ private final int sampleRate; // values <= 1 mean every call is logged
- /**
- * @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01
- * chance that a call to logGenerateLinks results in an event being written). To write all
- * events, pass 1.
- */
- public GenerateLinksLogger(int sampleRate) {
- mSampleRate = sampleRate;
- mRng = new Random();
+ /**
+ * @param sampleRate the rate at which log events are written. (e.g. 100 means there is a 0.01
+ * chance that a call to logGenerateLinks results in an event being written). To write all
+ * events, pass 1.
+ */
+ public GenerateLinksLogger(int sampleRate) {
+ this.sampleRate = sampleRate;
+ rng = new Random();
+ }
+
+ /** Logs statistics about a call to generateLinks. */
+ public void logGenerateLinks(
+ CharSequence text, TextLinks links, String callingPackageName, long latencyMs) {
+ Preconditions.checkNotNull(text);
+ Preconditions.checkNotNull(links);
+ Preconditions.checkNotNull(callingPackageName);
+ if (!shouldLog()) {
+ return;
}
- /** Logs statistics about a call to generateLinks. */
- public void logGenerateLinks(
- CharSequence text, TextLinks links, String callingPackageName, long latencyMs) {
- Preconditions.checkNotNull(text);
- Preconditions.checkNotNull(links);
- Preconditions.checkNotNull(callingPackageName);
- if (!shouldLog()) {
- return;
- }
-
- // Always populate the total stats, and per-entity stats for each entity type detected.
- final LinkifyStats totalStats = new LinkifyStats();
- final Map<String, LinkifyStats> perEntityTypeStats = new ArrayMap<>();
- for (TextLinks.TextLink link : links.getLinks()) {
- if (link.getEntityCount() == 0) continue;
- final String entityType = link.getEntity(0);
- if (entityType == null
- || TextClassifier.TYPE_OTHER.equals(entityType)
- || TextClassifier.TYPE_UNKNOWN.equals(entityType)) {
- continue;
- }
- totalStats.countLink(link);
- perEntityTypeStats.computeIfAbsent(entityType, k -> new LinkifyStats()).countLink(link);
- }
-
- final String callId = UUID.randomUUID().toString();
- writeStats(callId, callingPackageName, null, totalStats, text, latencyMs);
- for (Map.Entry<String, LinkifyStats> entry : perEntityTypeStats.entrySet()) {
- writeStats(
- callId, callingPackageName, entry.getKey(), entry.getValue(), text, latencyMs);
- }
+ // Always populate the total stats, and per-entity stats for each entity type detected.
+ final LinkifyStats totalStats = new LinkifyStats();
+ final Map<String, LinkifyStats> perEntityTypeStats = new ArrayMap<>();
+ for (TextLinks.TextLink link : links.getLinks()) {
+ if (link.getEntityCount() == 0) {
+ continue;
+ }
+ final String entityType = link.getEntity(0); // only the first entity of each link is counted
+ if (entityType == null
+ || TextClassifier.TYPE_OTHER.equals(entityType)
+ || TextClassifier.TYPE_UNKNOWN.equals(entityType)) {
+ continue;
+ }
+ totalStats.countLink(link);
+ perEntityTypeStats.computeIfAbsent(entityType, k -> new LinkifyStats()).countLink(link);
}
- /**
- * Returns whether this particular event should be logged.
- *
- * <p>Sampling is used to reduce the amount of logging data generated.
- */
- private boolean shouldLog() {
- if (mSampleRate <= 1) {
- return true;
- } else {
- return mRng.nextInt(mSampleRate) == 0;
- }
+ final String callId = UUID.randomUUID().toString(); // shared by the total and per-entity events of this call
+ writeStats(callId, callingPackageName, null, totalStats, text, latencyMs); // null entityType marks the all-entities total
+ for (Map.Entry<String, LinkifyStats> entry : perEntityTypeStats.entrySet()) {
+ writeStats(callId, callingPackageName, entry.getKey(), entry.getValue(), text, latencyMs);
}
+ }
- /** Writes a log event for the given stats. */
- private void writeStats(
- String callId,
- String callingPackageName,
- @Nullable String entityType,
- LinkifyStats stats,
- CharSequence text,
- long latencyMs) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
- callId,
- TextClassifierEvent.TYPE_LINKS_GENERATED,
- /*modelName=*/ null,
- TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN,
- /*eventIndex=*/ 0,
- entityType,
- stats.mNumLinks,
- stats.mNumLinksTextLength,
- text.length(),
- latencyMs,
- callingPackageName);
- if (TcLog.ENABLE_FULL_LOGGING) {
- TcLog.v(
- LOG_TAG,
- String.format(
- Locale.US,
- "%s:%s %d links (%d/%d chars) %dms %s",
- callId,
- entityType,
- stats.mNumLinks,
- stats.mNumLinksTextLength,
- text.length(),
- latencyMs,
- callingPackageName));
- }
+ /**
+ * Returns whether this particular event should be logged.
+ *
+ * <p>Sampling is used to reduce the amount of logging data generated.
+ */
+ private boolean shouldLog() {
+ if (sampleRate <= 1) {
+ return true;
+ } else {
+ return rng.nextInt(sampleRate) == 0;
}
+ }
- /** Helper class for storing per-entity type statistics. */
- private static final class LinkifyStats {
- int mNumLinks;
- int mNumLinksTextLength;
-
- void countLink(TextLinks.TextLink link) {
- mNumLinks += 1;
- mNumLinksTextLength += link.getEnd() - link.getStart();
- }
+ /** Writes a log event for the given stats. */
+ private static void writeStats( // one TEXT_LINKIFY_EVENT, plus a verbose TcLog line when full logging is on
+ String callId,
+ String callingPackageName,
+ @Nullable String entityType,
+ LinkifyStats stats,
+ CharSequence text,
+ long latencyMs) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
+ callId,
+ TextClassifierEvent.TYPE_LINKS_GENERATED,
+ /*modelName=*/ null,
+ TextClassifierEventLogger.WidgetType.WIDGET_TYPE_UNKNOWN,
+ /*eventIndex=*/ 0,
+ entityType,
+ stats.numLinks,
+ stats.numLinksTextLength,
+ text.length(),
+ latencyMs,
+ callingPackageName);
+ if (TcLog.ENABLE_FULL_LOGGING) {
+ TcLog.v(
+ LOG_TAG,
+ String.format(
+ Locale.US,
+ "%s:%s %d links (%d/%d chars) %dms %s",
+ callId,
+ entityType,
+ stats.numLinks,
+ stats.numLinksTextLength,
+ text.length(),
+ latencyMs,
+ callingPackageName));
}
+ }
+
+ /** Helper class for storing per-entity type statistics. */
+ private static final class LinkifyStats {
+ int numLinks; // number of links passed to countLink()
+ int numLinksTextLength; // total characters (end - start) covered by those links
+
+ void countLink(TextLinks.TextLink link) {
+ numLinks += 1;
+ numLinksTextLength += link.getEnd() - link.getStart();
+ }
+ }
}
diff --git a/java/src/com/android/textclassifier/logging/ResultIdUtils.java b/java/src/com/android/textclassifier/logging/ResultIdUtils.java
index de21959..c0f8b2a 100644
--- a/java/src/com/android/textclassifier/logging/ResultIdUtils.java
+++ b/java/src/com/android/textclassifier/logging/ResultIdUtils.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,59 +13,59 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.logging;
//
import android.content.Context;
-
-import androidx.annotation.Nullable;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.StringJoiner;
+import javax.annotation.Nullable;
/** Provide utils to generate and parse the result id. */
public final class ResultIdUtils {
- private static final String LOG_TAG = "ResultIdUtils";
- private static final String CLASSIFIER_ID = "androidtc";
+ private static final String CLASSIFIER_ID = "androidtc"; // first '|'-separated field of ids built by createId
- /** Creates a string id that may be used to identify a TextClassifier result. */
- public static String createId(
- Context context,
- String text,
- int start,
- int end,
- int modelVersion,
- List<Locale> modelLocales) {
- Preconditions.checkNotNull(text);
- Preconditions.checkNotNull(context);
- Preconditions.checkNotNull(modelLocales);
- final int hash = Objects.hash(text, start, end, context.getPackageName());
- return createId(modelVersion, modelLocales, hash);
- }
+ /** Creates a string id that may be used to identify a TextClassifier result. */
+ public static String createId(
+ Context context,
+ String text,
+ int start,
+ int end,
+ int modelVersion,
+ List<Locale> modelLocales) {
+ Preconditions.checkNotNull(text);
+ Preconditions.checkNotNull(context);
+ Preconditions.checkNotNull(modelLocales);
+ final int hash = Objects.hash(text, start, end, context.getPackageName()); // covers the span and the context's package
+ return createId(modelVersion, modelLocales, hash);
+ }
- /** Creates a string id that may be used to identify a TextClassifier result. */
- public static String createId(int modelVersion, List<Locale> modelLocales, int hash) {
- final StringJoiner localesJoiner = new StringJoiner(",");
- for (Locale locale : modelLocales) {
- localesJoiner.add(locale.toLanguageTag());
- }
- final String modelName =
- String.format(Locale.US, "%s_v%d", localesJoiner.toString(), modelVersion);
- return String.format(Locale.US, "%s|%s|%d", CLASSIFIER_ID, modelName, hash);
+ /** Creates a string id that may be used to identify a TextClassifier result. */
+ public static String createId(int modelVersion, List<Locale> modelLocales, int hash) {
+ final StringJoiner localesJoiner = new StringJoiner(",");
+ for (Locale locale : modelLocales) {
+ localesJoiner.add(locale.toLanguageTag());
}
+ final String modelName =
+ String.format(Locale.US, "%s_v%d", localesJoiner.toString(), modelVersion);
+ return String.format(Locale.US, "%s|%s|%d", CLASSIFIER_ID, modelName, hash); // e.g. androidtc|en-US,fr_v3|12345
+ }
- static String getModelName(@Nullable String signature) {
- if (signature == null) {
- return "";
- }
- final int start = signature.indexOf("|") + 1;
- final int end = signature.indexOf("|", start);
- if (start >= 1 && end >= start) {
- return signature.substring(start, end);
- }
- return "";
+ static String getModelName(@Nullable String signature) { // text between the first two '|' (model name for createId ids), else ""
+ if (signature == null) {
+ return "";
}
+ final int start = signature.indexOf("|") + 1;
+ final int end = signature.indexOf("|", start);
+ if (start >= 1 && end >= start) {
+ return signature.substring(start, end);
+ }
+ return "";
+ }
+
+ private ResultIdUtils() {} // static utility class; not instantiable
}
diff --git a/java/src/com/android/textclassifier/logging/SelectionEventConverter.java b/java/src/com/android/textclassifier/logging/SelectionEventConverter.java
index 2d492a9..05fdf9f 100644
--- a/java/src/com/android/textclassifier/logging/SelectionEventConverter.java
+++ b/java/src/com/android/textclassifier/logging/SelectionEventConverter.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,84 +19,85 @@
import android.view.textclassifier.SelectionEvent;
import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassifierEvent;
-
-import androidx.annotation.Nullable;
+import javax.annotation.Nullable;
/** Helper class to convert a {@link SelectionEvent} to a {@link TextClassifierEvent}. */
public final class SelectionEventConverter {
- /** Converts a {@link SelectionEvent} to a {@link TextClassifierEvent}. */
- @Nullable
- public static TextClassifierEvent toTextClassifierEvent(SelectionEvent selectionEvent) {
- TextClassificationContext textClassificationContext = null;
- if (selectionEvent.getPackageName() != null && selectionEvent.getWidgetType() != null) {
- textClassificationContext =
- new TextClassificationContext.Builder(
- selectionEvent.getPackageName(), selectionEvent.getWidgetType())
- .setWidgetVersion(selectionEvent.getWidgetVersion())
- .build();
- }
- if (selectionEvent.getInvocationMethod() == SelectionEvent.INVOCATION_LINK) {
- return new TextClassifierEvent.TextLinkifyEvent.Builder(
- convertEventType(selectionEvent.getEventType()))
- .setEventContext(textClassificationContext)
- .setResultId(selectionEvent.getResultId())
- .setEventIndex(selectionEvent.getEventIndex())
- .setEntityTypes(selectionEvent.getEntityType())
- .build();
- }
- if (selectionEvent.getInvocationMethod() == SelectionEvent.INVOCATION_MANUAL) {
- return new TextClassifierEvent.TextSelectionEvent.Builder(
- convertEventType(selectionEvent.getEventType()))
- .setEventContext(textClassificationContext)
- .setResultId(selectionEvent.getResultId())
- .setEventIndex(selectionEvent.getEventIndex())
- .setEntityTypes(selectionEvent.getEntityType())
- .setRelativeWordStartIndex(selectionEvent.getStart())
- .setRelativeWordEndIndex(selectionEvent.getEnd())
- .setRelativeSuggestedWordStartIndex(selectionEvent.getSmartStart())
- .setRelativeSuggestedWordEndIndex(selectionEvent.getSmartEnd())
- .build();
- }
- return null;
+ /** Converts a {@link SelectionEvent} to a {@link TextClassifierEvent}. */
+ @Nullable
+ public static TextClassifierEvent toTextClassifierEvent(SelectionEvent selectionEvent) {
+ TextClassificationContext textClassificationContext = null; // stays null unless package name and widget type are both set
+ if (selectionEvent.getPackageName() != null && selectionEvent.getWidgetType() != null) {
+ textClassificationContext =
+ new TextClassificationContext.Builder(
+ selectionEvent.getPackageName(), selectionEvent.getWidgetType())
+ .setWidgetVersion(selectionEvent.getWidgetVersion())
+ .build();
}
+ if (selectionEvent.getInvocationMethod() == SelectionEvent.INVOCATION_LINK) {
+ return new TextClassifierEvent.TextLinkifyEvent.Builder(
+ convertEventType(selectionEvent.getEventType()))
+ .setEventContext(textClassificationContext)
+ .setResultId(selectionEvent.getResultId())
+ .setEventIndex(selectionEvent.getEventIndex())
+ .setEntityTypes(selectionEvent.getEntityType())
+ .build();
+ }
+ if (selectionEvent.getInvocationMethod() == SelectionEvent.INVOCATION_MANUAL) {
+ return new TextClassifierEvent.TextSelectionEvent.Builder(
+ convertEventType(selectionEvent.getEventType()))
+ .setEventContext(textClassificationContext)
+ .setResultId(selectionEvent.getResultId())
+ .setEventIndex(selectionEvent.getEventIndex())
+ .setEntityTypes(selectionEvent.getEntityType())
+ .setRelativeWordStartIndex(selectionEvent.getStart())
+ .setRelativeWordEndIndex(selectionEvent.getEnd())
+ .setRelativeSuggestedWordStartIndex(selectionEvent.getSmartStart())
+ .setRelativeSuggestedWordEndIndex(selectionEvent.getSmartEnd())
+ .build();
+ }
+ return null; // only LINK and MANUAL invocation methods are converted
+ }
- private static int convertEventType(int eventType) {
- switch (eventType) {
- case SelectionEvent.EVENT_SELECTION_STARTED:
- return TextClassifierEvent.TYPE_SELECTION_STARTED;
- case SelectionEvent.EVENT_SELECTION_MODIFIED:
- return TextClassifierEvent.TYPE_SELECTION_MODIFIED;
- case SelectionEvent.EVENT_SMART_SELECTION_SINGLE:
- return SelectionEvent.EVENT_SMART_SELECTION_SINGLE;
- case SelectionEvent.EVENT_SMART_SELECTION_MULTI:
- return SelectionEvent.EVENT_SMART_SELECTION_MULTI;
- case SelectionEvent.EVENT_AUTO_SELECTION:
- return SelectionEvent.EVENT_AUTO_SELECTION;
- case SelectionEvent.ACTION_OVERTYPE:
- return TextClassifierEvent.TYPE_OVERTYPE;
- case SelectionEvent.ACTION_COPY:
- return TextClassifierEvent.TYPE_COPY_ACTION;
- case SelectionEvent.ACTION_PASTE:
- return TextClassifierEvent.TYPE_PASTE_ACTION;
- case SelectionEvent.ACTION_CUT:
- return TextClassifierEvent.TYPE_CUT_ACTION;
- case SelectionEvent.ACTION_SHARE:
- return TextClassifierEvent.TYPE_SHARE_ACTION;
- case SelectionEvent.ACTION_SMART_SHARE:
- return TextClassifierEvent.TYPE_SMART_ACTION;
- case SelectionEvent.ACTION_DRAG:
- return TextClassifierEvent.TYPE_SELECTION_DRAG;
- case SelectionEvent.ACTION_ABANDON:
- return TextClassifierEvent.TYPE_SELECTION_DESTROYED;
- case SelectionEvent.ACTION_OTHER:
- return TextClassifierEvent.TYPE_OTHER_ACTION;
- case SelectionEvent.ACTION_SELECT_ALL:
- return TextClassifierEvent.TYPE_SELECT_ALL;
- case SelectionEvent.ACTION_RESET:
- return TextClassifierEvent.TYPE_SELECTION_RESET;
- default:
- return 0;
- }
+ private static int convertEventType(int eventType) { // SelectionEvent EVENT_*/ACTION_* -> TextClassifierEvent TYPE_*
+ switch (eventType) {
+ case SelectionEvent.EVENT_SELECTION_STARTED:
+ return TextClassifierEvent.TYPE_SELECTION_STARTED;
+ case SelectionEvent.EVENT_SELECTION_MODIFIED:
+ return TextClassifierEvent.TYPE_SELECTION_MODIFIED;
+ case SelectionEvent.EVENT_SMART_SELECTION_SINGLE:
+ return TextClassifierEvent.TYPE_SMART_SELECTION_SINGLE;
+ case SelectionEvent.EVENT_SMART_SELECTION_MULTI:
+ return TextClassifierEvent.TYPE_SMART_SELECTION_MULTI;
+ case SelectionEvent.EVENT_AUTO_SELECTION:
+ return TextClassifierEvent.TYPE_AUTO_SELECTION;
+ case SelectionEvent.ACTION_OVERTYPE:
+ return TextClassifierEvent.TYPE_OVERTYPE;
+ case SelectionEvent.ACTION_COPY:
+ return TextClassifierEvent.TYPE_COPY_ACTION;
+ case SelectionEvent.ACTION_PASTE:
+ return TextClassifierEvent.TYPE_PASTE_ACTION;
+ case SelectionEvent.ACTION_CUT:
+ return TextClassifierEvent.TYPE_CUT_ACTION;
+ case SelectionEvent.ACTION_SHARE:
+ return TextClassifierEvent.TYPE_SHARE_ACTION;
+ case SelectionEvent.ACTION_SMART_SHARE:
+ return TextClassifierEvent.TYPE_SMART_ACTION;
+ case SelectionEvent.ACTION_DRAG:
+ return TextClassifierEvent.TYPE_SELECTION_DRAG;
+ case SelectionEvent.ACTION_ABANDON:
+ return TextClassifierEvent.TYPE_SELECTION_DESTROYED;
+ case SelectionEvent.ACTION_OTHER:
+ return TextClassifierEvent.TYPE_OTHER_ACTION;
+ case SelectionEvent.ACTION_SELECT_ALL:
+ return TextClassifierEvent.TYPE_SELECT_ALL;
+ case SelectionEvent.ACTION_RESET:
+ return TextClassifierEvent.TYPE_SELECTION_RESET;
+ default:
+ return 0; // no TextClassifierEvent counterpart for this event type
}
+ }
+
+ private SelectionEventConverter() {}
}
diff --git a/java/src/com/android/textclassifier/logging/TextClassifierEventLogger.java b/java/src/com/android/textclassifier/logging/TextClassifierEventLogger.java
index 29e38a1..a112fdf 100644
--- a/java/src/com/android/textclassifier/logging/TextClassifierEventLogger.java
+++ b/java/src/com/android/textclassifier/logging/TextClassifierEventLogger.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -13,203 +13,203 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+
package com.android.textclassifier.logging;
import android.view.textclassifier.TextClassificationContext;
import android.view.textclassifier.TextClassificationSessionId;
import android.view.textclassifier.TextClassifier;
import android.view.textclassifier.TextClassifierEvent;
-
-import androidx.annotation.Nullable;
-import androidx.core.util.Preconditions;
-
import com.android.textclassifier.TcLog;
import com.android.textclassifier.TextClassifierStatsLog;
+import com.google.common.base.Preconditions;
+import javax.annotation.Nullable;
/** Logs {@link android.view.textclassifier.TextClassifierEvent}. */
public final class TextClassifierEventLogger {
- private static final String TAG = "TextClassifierEventLogger";
+ private static final String TAG = "TCEventLogger";
- /** Emits a text classifier event to the logs. */
- public void writeEvent(
- @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) {
- Preconditions.checkNotNull(event);
- if (TcLog.ENABLE_FULL_LOGGING) {
- TcLog.v(TAG, "TextClassifierEventLogger.writeEvent: event = [" + event + "]");
- }
- if (event instanceof TextClassifierEvent.TextSelectionEvent) {
- logTextSelectionEvent(sessionId, (TextClassifierEvent.TextSelectionEvent) event);
- } else if (event instanceof TextClassifierEvent.TextLinkifyEvent) {
- logTextLinkifyEvent(sessionId, (TextClassifierEvent.TextLinkifyEvent) event);
- } else if (event instanceof TextClassifierEvent.ConversationActionsEvent) {
- logConversationActionsEvent(
- sessionId, (TextClassifierEvent.ConversationActionsEvent) event);
- } else if (event instanceof TextClassifierEvent.LanguageDetectionEvent) {
- logLanguageDetectionEvent(
- sessionId, (TextClassifierEvent.LanguageDetectionEvent) event);
- } else {
- TcLog.w(TAG, "Unexpected events, category=" + event.getEventCategory());
- }
+ /** Emits a text classifier event to the logs. */
+ public void writeEvent(
+ @Nullable TextClassificationSessionId sessionId, TextClassifierEvent event) {
+ Preconditions.checkNotNull(event);
+ if (TcLog.ENABLE_FULL_LOGGING) {
+ TcLog.v(TAG, "TextClassifierEventLogger.writeEvent: event = [" + event + "]");
}
-
- private void logTextSelectionEvent(
- @Nullable TextClassificationSessionId sessionId,
- TextClassifierEvent.TextSelectionEvent event) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.TEXT_SELECTION_EVENT,
- sessionId == null ? null : sessionId.flattenToString(),
- event.getEventType(),
- getModelName(event),
- getWidgetType(event),
- event.getEventIndex(),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- event.getRelativeWordStartIndex(),
- event.getRelativeWordEndIndex(),
- event.getRelativeSuggestedWordStartIndex(),
- event.getRelativeSuggestedWordEndIndex(),
- getPackageName(event));
+ if (event instanceof TextClassifierEvent.TextSelectionEvent) {
+ logTextSelectionEvent(sessionId, (TextClassifierEvent.TextSelectionEvent) event);
+ } else if (event instanceof TextClassifierEvent.TextLinkifyEvent) {
+ logTextLinkifyEvent(sessionId, (TextClassifierEvent.TextLinkifyEvent) event);
+ } else if (event instanceof TextClassifierEvent.ConversationActionsEvent) {
+ logConversationActionsEvent(sessionId, (TextClassifierEvent.ConversationActionsEvent) event);
+ } else if (event instanceof TextClassifierEvent.LanguageDetectionEvent) {
+ logLanguageDetectionEvent(sessionId, (TextClassifierEvent.LanguageDetectionEvent) event);
+ } else {
+ TcLog.w(TAG, "Unexpected events, category=" + event.getEventCategory());
}
+ }
- private void logTextLinkifyEvent(
- TextClassificationSessionId sessionId, TextClassifierEvent.TextLinkifyEvent event) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
- sessionId == null ? null : sessionId.flattenToString(),
- event.getEventType(),
- getModelName(event),
- getWidgetType(event),
- event.getEventIndex(),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- /*numOfLinks=*/ 0,
- /*linkedTextLength=*/ 0,
- /*textLength=*/ 0,
- /*latencyInMillis=*/ 0L,
- getPackageName(event));
+ private void logTextSelectionEvent(
+ @Nullable TextClassificationSessionId sessionId,
+ TextClassifierEvent.TextSelectionEvent event) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_SELECTION_EVENT,
+ sessionId == null ? null : sessionId.flattenToString(),
+ event.getEventType(),
+ getModelName(event),
+ getWidgetType(event),
+ event.getEventIndex(),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ event.getRelativeWordStartIndex(),
+ event.getRelativeWordEndIndex(),
+ event.getRelativeSuggestedWordStartIndex(),
+ event.getRelativeSuggestedWordEndIndex(),
+ getPackageName(event));
+ }
+
+ private void logTextLinkifyEvent(
+ TextClassificationSessionId sessionId, TextClassifierEvent.TextLinkifyEvent event) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.TEXT_LINKIFY_EVENT,
+ sessionId == null ? null : sessionId.flattenToString(),
+ event.getEventType(),
+ getModelName(event),
+ getWidgetType(event),
+ event.getEventIndex(),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ /*numOfLinks=*/ 0,
+ /*linkedTextLength=*/ 0,
+ /*textLength=*/ 0,
+ /*latencyInMillis=*/ 0L,
+ getPackageName(event));
+ }
+
+ private void logConversationActionsEvent(
+ @Nullable TextClassificationSessionId sessionId,
+ TextClassifierEvent.ConversationActionsEvent event) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.CONVERSATION_ACTIONS_EVENT,
+ sessionId == null
+ ? event.getResultId() // TODO: Update ExtServices to set the session id.
+ : sessionId.flattenToString(),
+ event.getEventType(),
+ getModelName(event),
+ getWidgetType(event),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ getItemAt(event.getEntityTypes(), /* index= */ 1),
+ getItemAt(event.getEntityTypes(), /* index= */ 2),
+ getFloatAt(event.getScores(), /* index= */ 0),
+ getPackageName(event));
+ }
+
+ private void logLanguageDetectionEvent(
+ @Nullable TextClassificationSessionId sessionId,
+ TextClassifierEvent.LanguageDetectionEvent event) {
+ TextClassifierStatsLog.write(
+ TextClassifierStatsLog.LANGUAGE_DETECTION_EVENT,
+ sessionId == null ? null : sessionId.flattenToString(),
+ event.getEventType(),
+ getModelName(event),
+ getWidgetType(event),
+ getItemAt(event.getEntityTypes(), /* index= */ 0),
+ getFloatAt(event.getScores(), /* index= */ 0),
+ getIntAt(event.getActionIndices(), /* index= */ 0),
+ getPackageName(event));
+ }
+
+ @Nullable
+ private static <T> T getItemAt(@Nullable T[] array, int index) {
+ if (array == null) {
+ return null;
}
-
- private void logConversationActionsEvent(
- @Nullable TextClassificationSessionId sessionId,
- TextClassifierEvent.ConversationActionsEvent event) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.CONVERSATION_ACTIONS_EVENT,
- sessionId == null
- ? event.getResultId() // TODO: Update ExtServices to set the session id.
- : sessionId.flattenToString(),
- event.getEventType(),
- getModelName(event),
- getWidgetType(event),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- getItemAt(event.getEntityTypes(), /* index= */ 1),
- getItemAt(event.getEntityTypes(), /* index= */ 2),
- getFloatAt(event.getScores(), /* index= */ 0),
- getPackageName(event));
+ if (index >= array.length) {
+ return null;
}
+ return array[index];
+ }
- private void logLanguageDetectionEvent(
- @Nullable TextClassificationSessionId sessionId,
- TextClassifierEvent.LanguageDetectionEvent event) {
- TextClassifierStatsLog.write(
- TextClassifierStatsLog.LANGUAGE_DETECTION_EVENT,
- sessionId == null ? null : sessionId.flattenToString(),
- event.getEventType(),
- getModelName(event),
- getWidgetType(event),
- getItemAt(event.getEntityTypes(), /* index= */ 0),
- getFloatAt(event.getScores(), /* index= */ 0),
- getIntAt(event.getActionIndices(), /* index= */ 0),
- getPackageName(event));
+ private static float getFloatAt(@Nullable float[] array, int index) {
+ if (array == null) {
+ return 0f;
}
-
- @Nullable
- private static <T> T getItemAt(@Nullable T[] array, int index) {
- if (array == null) {
- return null;
- }
- if (index >= array.length) {
- return null;
- }
- return array[index];
+ if (index >= array.length) {
+ return 0f;
}
+ return array[index];
+ }
- private static float getFloatAt(@Nullable float[] array, int index) {
- if (array == null) {
- return 0f;
- }
- if (index >= array.length) {
- return 0f;
- }
- return array[index];
+ private static int getIntAt(@Nullable int[] array, int index) {
+ if (array == null) {
+ return 0;
}
-
- private static int getIntAt(@Nullable int[] array, int index) {
- if (array == null) {
- return 0;
- }
- if (index >= array.length) {
- return 0;
- }
- return array[index];
+ if (index >= array.length) {
+ return 0;
}
+ return array[index];
+ }
- private static String getModelName(TextClassifierEvent event) {
- if (event.getModelName() != null) {
- return event.getModelName();
- }
- return ResultIdUtils.getModelName(event.getResultId());
+ private static String getModelName(TextClassifierEvent event) {
+ if (event.getModelName() != null) {
+ return event.getModelName();
}
+ return ResultIdUtils.getModelName(event.getResultId());
+ }
- @Nullable
- private static String getPackageName(TextClassifierEvent event) {
- TextClassificationContext eventContext = event.getEventContext();
- if (eventContext == null) {
- return null;
- }
- return eventContext.getPackageName();
+ @Nullable
+ private static String getPackageName(TextClassifierEvent event) {
+ TextClassificationContext eventContext = event.getEventContext();
+ if (eventContext == null) {
+ return null;
}
+ return eventContext.getPackageName();
+ }
- private static int getWidgetType(TextClassifierEvent event) {
- TextClassificationContext eventContext = event.getEventContext();
- if (eventContext == null) {
- return WidgetType.WIDGET_TYPE_UNKNOWN;
- }
- switch (eventContext.getWidgetType()) {
- case TextClassifier.WIDGET_TYPE_UNKNOWN:
- return WidgetType.WIDGET_TYPE_UNKNOWN;
- case TextClassifier.WIDGET_TYPE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_EDITTEXT:
- return WidgetType.WIDGET_TYPE_EDITTEXT;
- case TextClassifier.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_WEBVIEW:
- return WidgetType.WIDGET_TYPE_WEBVIEW;
- case TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW:
- return WidgetType.WIDGET_TYPE_EDIT_WEBVIEW;
- case TextClassifier.WIDGET_TYPE_CUSTOM_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_CUSTOM_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_CUSTOM_EDITTEXT:
- return WidgetType.WIDGET_TYPE_CUSTOM_EDITTEXT;
- case TextClassifier.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW:
- return WidgetType.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW;
- case TextClassifier.WIDGET_TYPE_NOTIFICATION:
- return WidgetType.WIDGET_TYPE_NOTIFICATION;
- }
+ private static int getWidgetType(TextClassifierEvent event) {
+ TextClassificationContext eventContext = event.getEventContext();
+ if (eventContext == null) {
+ return WidgetType.WIDGET_TYPE_UNKNOWN;
+ }
+ switch (eventContext.getWidgetType()) {
+ case TextClassifier.WIDGET_TYPE_UNKNOWN:
return WidgetType.WIDGET_TYPE_UNKNOWN;
+ case TextClassifier.WIDGET_TYPE_TEXTVIEW:
+ return WidgetType.WIDGET_TYPE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_EDITTEXT:
+ return WidgetType.WIDGET_TYPE_EDITTEXT;
+ case TextClassifier.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW:
+ return WidgetType.WIDGET_TYPE_UNSELECTABLE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_WEBVIEW:
+ return WidgetType.WIDGET_TYPE_WEBVIEW;
+ case TextClassifier.WIDGET_TYPE_EDIT_WEBVIEW:
+ return WidgetType.WIDGET_TYPE_EDIT_WEBVIEW;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_TEXTVIEW:
+ return WidgetType.WIDGET_TYPE_CUSTOM_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_EDITTEXT:
+ return WidgetType.WIDGET_TYPE_CUSTOM_EDITTEXT;
+ case TextClassifier.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW:
+ return WidgetType.WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW;
+ case TextClassifier.WIDGET_TYPE_NOTIFICATION:
+ return WidgetType.WIDGET_TYPE_NOTIFICATION;
+ default: // fall out
}
+ return WidgetType.WIDGET_TYPE_UNKNOWN;
+ }
- /** Widget type constants for logging. */
- public interface WidgetType {
- // Sync these constants with textclassifier_enums.proto.
- int WIDGET_TYPE_UNKNOWN = 0;
- int WIDGET_TYPE_TEXTVIEW = 1;
- int WIDGET_TYPE_EDITTEXT = 2;
- int WIDGET_TYPE_UNSELECTABLE_TEXTVIEW = 3;
- int WIDGET_TYPE_WEBVIEW = 4;
- int WIDGET_TYPE_EDIT_WEBVIEW = 5;
- int WIDGET_TYPE_CUSTOM_TEXTVIEW = 6;
- int WIDGET_TYPE_CUSTOM_EDITTEXT = 7;
- int WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW = 8;
- int WIDGET_TYPE_NOTIFICATION = 9;
- }
+ /** Widget type constants for logging. */
+ public static final class WidgetType {
+ // Sync these constants with textclassifier_enums.proto.
+ public static final int WIDGET_TYPE_UNKNOWN = 0;
+ public static final int WIDGET_TYPE_TEXTVIEW = 1;
+ public static final int WIDGET_TYPE_EDITTEXT = 2;
+ public static final int WIDGET_TYPE_UNSELECTABLE_TEXTVIEW = 3;
+ public static final int WIDGET_TYPE_WEBVIEW = 4;
+ public static final int WIDGET_TYPE_EDIT_WEBVIEW = 5;
+ public static final int WIDGET_TYPE_CUSTOM_TEXTVIEW = 6;
+ public static final int WIDGET_TYPE_CUSTOM_EDITTEXT = 7;
+ public static final int WIDGET_TYPE_CUSTOM_UNSELECTABLE_TEXTVIEW = 8;
+ public static final int WIDGET_TYPE_NOTIFICATION = 9;
+
+ private WidgetType() {}
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzer.java b/java/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzer.java
index 64a9fe8..66db748 100644
--- a/java/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzer.java
+++ b/java/src/com/android/textclassifier/ulp/BasicLanguageProficiencyAnalyzer.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,14 +19,11 @@
import android.content.Context;
import android.util.ArrayMap;
import android.view.textclassifier.TextClassifierEvent;
-
import com.android.textclassifier.TextClassificationConstants;
import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
import com.android.textclassifier.ulp.database.LanguageSignalInfo;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-
import java.util.Collections;
import java.util.List;
import java.util.Map;
@@ -42,77 +39,74 @@
*/
final class BasicLanguageProficiencyAnalyzer implements LanguageProficiencyAnalyzer {
- private static final long CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME =
- TimeUnit.HOURS.toMillis(6);
+ private static final long CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME =
+ TimeUnit.HOURS.toMillis(6);
- private final TextClassificationConstants mSettings;
- private final LanguageProfileDatabase mDatabase;
- private final SystemLanguagesProvider mSystemLanguagesProvider;
+ private final TextClassificationConstants settings;
+ private final LanguageProfileDatabase database;
+ private final SystemLanguagesProvider systemLanguagesProvider;
- private Map<String, Float> mCanUnderstandResultCache;
- private long mCanUnderstandResultCacheTime;
+ private Map<String, Float> canUnderstandResultCache;
+ private long canUnderstandResultCacheTime;
- BasicLanguageProficiencyAnalyzer(
- Context context,
- TextClassificationConstants settings,
- SystemLanguagesProvider systemLanguagesProvider) {
- this(settings, LanguageProfileDatabase.getInstance(context), systemLanguagesProvider);
+ BasicLanguageProficiencyAnalyzer(
+ Context context,
+ TextClassificationConstants settings,
+ SystemLanguagesProvider systemLanguagesProvider) {
+ this(settings, LanguageProfileDatabase.getInstance(context), systemLanguagesProvider);
+ }
+
+ @VisibleForTesting
+ BasicLanguageProficiencyAnalyzer(
+ TextClassificationConstants settings,
+ LanguageProfileDatabase languageProfileDatabase,
+ SystemLanguagesProvider systemLanguagesProvider) {
+ this.settings = Preconditions.checkNotNull(settings);
+ database = Preconditions.checkNotNull(languageProfileDatabase);
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ canUnderstandResultCache = new ArrayMap<>();
+ }
+
+ @Override
+ public synchronized float canUnderstand(String languageTag) {
+ if (canUnderstandResultCache.isEmpty()
+ || (System.currentTimeMillis() - canUnderstandResultCacheTime)
+ >= CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME) {
+ canUnderstandResultCache = createCanUnderstandResultCache();
+ canUnderstandResultCacheTime = System.currentTimeMillis();
}
+ return canUnderstandResultCache.getOrDefault(languageTag, 0f);
+ }
- @VisibleForTesting
- BasicLanguageProficiencyAnalyzer(
- TextClassificationConstants settings,
- LanguageProfileDatabase languageProfileDatabase,
- SystemLanguagesProvider systemLanguagesProvider) {
- mSettings = Preconditions.checkNotNull(settings);
- mDatabase = Preconditions.checkNotNull(languageProfileDatabase);
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
- mCanUnderstandResultCache = new ArrayMap<>();
+ @Override
+ public void onTextClassifierEvent(TextClassifierEvent event) {}
+
+ @Override
+ public boolean shouldShowTranslation(String languageCode) {
+ return canUnderstand(languageCode) >= settings.getTranslateActionThreshold();
+ }
+
+ private Map<String, Float> createCanUnderstandResultCache() {
+ Map<String, Float> result = new ArrayMap<>();
+ List<String> systemLanguageTags = systemLanguagesProvider.getSystemLanguageTags();
+ List<LanguageSignalInfo> languageSignalInfos =
+ database.languageInfoDao().getBySource(LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
+ // Applies system languages to bootstrap the model according to Zipf's Law.
+ // Zipf’s Law states that the ith most common language should be proportional to 1/i.
+ for (int i = 0; i < systemLanguageTags.size(); i++) {
+ String languageTag = systemLanguageTags.get(i);
+ result.put(
+ languageTag, settings.getLanguageProficiencyBootstrappingCount() / (float) (i + 1));
}
-
- @Override
- public synchronized float canUnderstand(String languageTag) {
- if (mCanUnderstandResultCache.isEmpty()
- || (System.currentTimeMillis() - mCanUnderstandResultCacheTime)
- >= CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME) {
- mCanUnderstandResultCache = createCanUnderstandResultCache();
- mCanUnderstandResultCacheTime = System.currentTimeMillis();
- }
- return mCanUnderstandResultCache.getOrDefault(languageTag, 0f);
+ // Adds message counts of different languages into the corresponding entry in the map
+ for (LanguageSignalInfo info : languageSignalInfos) {
+ String languageTag = info.getLanguageTag();
+ int count = info.getCount();
+ result.put(languageTag, result.getOrDefault(languageTag, 0f) + count);
}
-
- @Override
- public void onTextClassifierEvent(TextClassifierEvent event) {}
-
- @Override
- public boolean shouldShowTranslation(String languageCode) {
- return canUnderstand(languageCode) >= mSettings.getTranslateActionThreshold();
- }
-
- private Map<String, Float> createCanUnderstandResultCache() {
- Map<String, Float> result = new ArrayMap<>();
- List<String> systemLanguageTags = mSystemLanguagesProvider.getSystemLanguageTags();
- List<LanguageSignalInfo> languageSignalInfos =
- mDatabase
- .languageInfoDao()
- .getBySource(LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
- // Applies system languages to bootstrap the model according to Zipf's Law.
- // Zipf’s Law states that the ith most common language should be proportional to 1/i.
- for (int i = 0; i < systemLanguageTags.size(); i++) {
- String languageTag = systemLanguageTags.get(i);
- result.put(
- languageTag,
- mSettings.getLanguageProficiencyBootstrappingCount() / (float) (i + 1));
- }
- // Adds message counts of different languages into the corresponding entry in the map
- for (LanguageSignalInfo info : languageSignalInfos) {
- String languageTag = info.getLanguageTag();
- int count = info.getCount();
- result.put(languageTag, result.getOrDefault(languageTag, 0f) + count);
- }
- // Calculates confidence scores
- float max = Collections.max(result.values());
- result.forEach((languageTag, count) -> result.put(languageTag, count / max));
- return result;
- }
+ // Calculates confidence scores
+ float max = Collections.max(result.values());
+ result.forEach((languageTag, count) -> result.put(languageTag, count / max));
+ return result;
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzer.java b/java/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzer.java
index e33e45b..28d97e4 100644
--- a/java/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzer.java
+++ b/java/src/com/android/textclassifier/ulp/KmeansLanguageProficiencyAnalyzer.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,17 +18,13 @@
import android.content.Context;
import android.view.textclassifier.TextClassifierEvent;
-
import androidx.collection.ArrayMap;
-
import com.android.textclassifier.TextClassificationConstants;
import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
import com.android.textclassifier.ulp.database.LanguageSignalInfo;
import com.android.textclassifier.ulp.kmeans.KMeans;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@@ -45,127 +41,126 @@
// STOPSHIP: Review the entire ULP package before shipping it.
final class KmeansLanguageProficiencyAnalyzer implements LanguageProficiencyAnalyzer {
- private static final long CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME =
- TimeUnit.HOURS.toMillis(6);
+ private static final long CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME =
+ TimeUnit.HOURS.toMillis(6);
- private final TextClassificationConstants mSettings;
- private final LanguageProfileDatabase mDatabase;
- private final KMeans mKmeans;
- private final SystemLanguagesProvider mSystemLanguagesProvider;
+ private final TextClassificationConstants settings;
+ private final LanguageProfileDatabase database;
+ private final KMeans kmeans;
+ private final SystemLanguagesProvider systemLanguagesProvider;
- private Map<String, Float> mCanUnderstandResultCache;
- private long mCanUnderstandResultCacheTime;
+ private Map<String, Float> canUnderstandResultCache;
+ private long canUnderstandResultCacheTime;
- KmeansLanguageProficiencyAnalyzer(
- Context context,
- TextClassificationConstants settings,
- SystemLanguagesProvider systemLanguagesProvider) {
- this(settings, LanguageProfileDatabase.getInstance(context), systemLanguagesProvider);
+ KmeansLanguageProficiencyAnalyzer(
+ Context context,
+ TextClassificationConstants settings,
+ SystemLanguagesProvider systemLanguagesProvider) {
+ this(settings, LanguageProfileDatabase.getInstance(context), systemLanguagesProvider);
+ }
+
+ @VisibleForTesting
+ KmeansLanguageProficiencyAnalyzer(
+ TextClassificationConstants settings,
+ LanguageProfileDatabase languageProfileDatabase,
+ SystemLanguagesProvider systemLanguagesProvider) {
+ this.settings = Preconditions.checkNotNull(settings);
+ database = Preconditions.checkNotNull(languageProfileDatabase);
+ kmeans = new KMeans();
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ canUnderstandResultCache = new ArrayMap<>();
+ }
+
+ @Override
+ public synchronized float canUnderstand(String languageTag) {
+ if (canUnderstandResultCache.isEmpty()
+ || (System.currentTimeMillis() - canUnderstandResultCacheTime)
+ >= CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME) {
+ canUnderstandResultCache = createCanUnderstandResultCache();
+ canUnderstandResultCacheTime = System.currentTimeMillis();
}
+ return canUnderstandResultCache.getOrDefault(languageTag, 0f);
+ }
- @VisibleForTesting
- KmeansLanguageProficiencyAnalyzer(
- TextClassificationConstants settings,
- LanguageProfileDatabase languageProfileDatabase,
- SystemLanguagesProvider systemLanguagesProvider) {
- mSettings = Preconditions.checkNotNull(settings);
- mDatabase = Preconditions.checkNotNull(languageProfileDatabase);
- mKmeans = new KMeans();
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
- mCanUnderstandResultCache = new ArrayMap<>();
+ private Map<String, Float> createCanUnderstandResultCache() {
+ Map<String, Float> result = new ArrayMap<>();
+ ArrayMap<String, Integer> languageCounts = new ArrayMap<>();
+ List<String> systemLanguageTags = systemLanguagesProvider.getSystemLanguageTags();
+ List<LanguageSignalInfo> languageSignalInfos =
+ database.languageInfoDao().getBySource(LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
+ // Applies system languages to bootstrap the model according to Zipf's Law.
+ // Zipf’s Law states that the ith most common language should be proportional to 1/i.
+ for (int i = 0; i < systemLanguageTags.size(); i++) {
+ String languageTag = systemLanguageTags.get(i);
+ languageCounts.put(
+ languageTag, settings.getLanguageProficiencyBootstrappingCount() / (i + 1));
}
-
- @Override
- public synchronized float canUnderstand(String languageTag) {
- if (mCanUnderstandResultCache.isEmpty()
- || (System.currentTimeMillis() - mCanUnderstandResultCacheTime)
- >= CAN_UNDERSTAND_RESULT_CACHE_EXPIRATION_TIME) {
- mCanUnderstandResultCache = createCanUnderstandResultCache();
- mCanUnderstandResultCacheTime = System.currentTimeMillis();
- }
- return mCanUnderstandResultCache.getOrDefault(languageTag, 0f);
+ // Adds message counts of different languages into the corresponding entry in the map
+ for (LanguageSignalInfo info : languageSignalInfos) {
+ String languageTag = info.getLanguageTag();
+ int count = info.getCount();
+ languageCounts.put(languageTag, languageCounts.getOrDefault(languageTag, 0) + count);
}
-
- private Map<String, Float> createCanUnderstandResultCache() {
- Map<String, Float> result = new ArrayMap<>();
- ArrayMap<String, Integer> languageCounts = new ArrayMap<>();
- List<String> systemLanguageTags = mSystemLanguagesProvider.getSystemLanguageTags();
- List<LanguageSignalInfo> languageSignalInfos =
- mDatabase
- .languageInfoDao()
- .getBySource(LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
- // Applies system languages to bootstrap the model according to Zipf's Law.
- // Zipf’s Law states that the ith most common language should be proportional to 1/i.
- for (int i = 0; i < systemLanguageTags.size(); i++) {
- String languageTag = systemLanguageTags.get(i);
- languageCounts.put(
- languageTag, mSettings.getLanguageProficiencyBootstrappingCount() / (i + 1));
- }
- // Adds message counts of different languages into the corresponding entry in the map
- for (LanguageSignalInfo info : languageSignalInfos) {
- String languageTag = info.getLanguageTag();
- int count = info.getCount();
- languageCounts.put(languageTag, languageCounts.getOrDefault(languageTag, 0) + count);
- }
- // Calculates confidence scores
- if (languageCounts.size() == 1) {
- result.put(languageCounts.keyAt(0), 1f);
- return result;
- }
- if (languageCounts.size() == 2) {
- return evaluateTwoLanguageCounts(languageCounts);
- }
- // Applies K-Means to cluster data points
- int size = languageCounts.size();
- float[][] inputData = new float[size][1];
- for (int i = 0; i < size; i++) {
- inputData[i][0] = languageCounts.valueAt(i);
- }
- List<KMeans.Mean> means = mKmeans.predict(/* k= */ 2, inputData);
- List<Integer> countsInMaxCluster = getCountsWithinFarthestCluster(means);
- for (int i = 0; i < languageCounts.size(); i++) {
- float score = countsInMaxCluster.contains(languageCounts.valueAt(i)) ? 1f : 0f;
- result.put(languageCounts.keyAt(i), score);
- }
- return result;
+ // Calculates confidence scores
+ if (languageCounts.size() == 1) {
+ result.put(languageCounts.keyAt(0), 1f);
+ return result;
}
-
- @Override
- public void onTextClassifierEvent(TextClassifierEvent event) {}
-
- @Override
- public boolean shouldShowTranslation(String languageCode) {
- return canUnderstand(languageCode) >= mSettings.getTranslateActionThreshold();
+ if (languageCounts.size() == 2) {
+ return evaluateTwoLanguageCounts(languageCounts);
}
-
- private Map<String, Float> evaluateTwoLanguageCounts(ArrayMap<String, Integer> languageCounts) {
- Map<String, Float> result = new ArrayMap<>();
- int countOne = languageCounts.valueAt(0);
- String languageTagOne = languageCounts.keyAt(0);
- int countTwo = languageCounts.valueAt(1);
- String languageTagTwo = languageCounts.keyAt(1);
- if (countOne >= countTwo) {
- result.put(languageTagOne, 1f);
- result.put(languageTagTwo, countTwo / (float) countOne);
- } else {
- result.put(languageTagTwo, 1f);
- result.put(languageTagOne, countOne / (float) countTwo);
- }
- return result;
+ // Applies K-Means to cluster data points
+ int size = languageCounts.size();
+ float[][] inputData = new float[size][1];
+ for (int i = 0; i < size; i++) {
+ inputData[i][0] = languageCounts.valueAt(i);
}
-
- private List<Integer> getCountsWithinFarthestCluster(List<KMeans.Mean> means) {
- List<Integer> result = new ArrayList<>();
- KMeans.Mean farthestMean = means.get(0);
- for (int i = 1; i < means.size(); i++) {
- KMeans.Mean curMean = means.get(i);
- if (curMean.getCentroid()[0] > farthestMean.getCentroid()[0]) {
- farthestMean = curMean;
- }
- }
- for (float[] item : farthestMean.getItems()) {
- result.add((int) item[0]);
- }
- return result;
+ List<KMeans.Mean> means = kmeans.predict(/* k= */ 2, inputData);
+ List<Integer> countsInMaxCluster = getCountsWithinFarthestCluster(means);
+ for (int i = 0; i < languageCounts.size(); i++) {
+ float score = countsInMaxCluster.contains(languageCounts.valueAt(i)) ? 1f : 0f;
+ result.put(languageCounts.keyAt(i), score);
}
+ return result;
+ }
+
+ @Override
+ public void onTextClassifierEvent(TextClassifierEvent event) {}
+
+ @Override
+ public boolean shouldShowTranslation(String languageCode) {
+ return canUnderstand(languageCode) >= settings.getTranslateActionThreshold();
+ }
+
+ private static Map<String, Float> evaluateTwoLanguageCounts(
+ ArrayMap<String, Integer> languageCounts) {
+ Map<String, Float> result = new ArrayMap<>();
+ int countOne = languageCounts.valueAt(0);
+ String languageTagOne = languageCounts.keyAt(0);
+ int countTwo = languageCounts.valueAt(1);
+ String languageTagTwo = languageCounts.keyAt(1);
+ if (countOne >= countTwo) {
+ result.put(languageTagOne, 1f);
+ result.put(languageTagTwo, countTwo / (float) countOne);
+ } else {
+ result.put(languageTagTwo, 1f);
+ result.put(languageTagOne, countOne / (float) countTwo);
+ }
+ return result;
+ }
+
+ private static List<Integer> getCountsWithinFarthestCluster(List<KMeans.Mean> means) {
+ List<Integer> result = new ArrayList<>();
+ KMeans.Mean farthestMean = means.get(0);
+ for (int i = 1; i < means.size(); i++) {
+ KMeans.Mean curMean = means.get(i);
+ if (curMean.getCentroid()[0] > farthestMean.getCentroid()[0]) {
+ farthestMean = curMean;
+ }
+ }
+ for (float[] item : farthestMean.getItems()) {
+ result.add((int) item[0]);
+ }
+ return result;
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/LanguageProficiencyAnalyzer.java b/java/src/com/android/textclassifier/ulp/LanguageProficiencyAnalyzer.java
index 89ea236..c9e8695 100644
--- a/java/src/com/android/textclassifier/ulp/LanguageProficiencyAnalyzer.java
+++ b/java/src/com/android/textclassifier/ulp/LanguageProficiencyAnalyzer.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,14 +17,13 @@
package com.android.textclassifier.ulp;
import android.view.textclassifier.TextClassifierEvent;
-
import androidx.annotation.FloatRange;
interface LanguageProficiencyAnalyzer {
- @FloatRange(from = 0.0, to = 1.0)
- float canUnderstand(String languageCode);
+ @FloatRange(from = 0.0, to = 1.0)
+ float canUnderstand(String languageCode);
- void onTextClassifierEvent(TextClassifierEvent event);
+ void onTextClassifierEvent(TextClassifierEvent event);
- boolean shouldShowTranslation(String languageCode);
+ boolean shouldShowTranslation(String languageCode);
}
diff --git a/java/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluator.java b/java/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluator.java
index c8ed732..7c16418 100644
--- a/java/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluator.java
+++ b/java/src/com/android/textclassifier/ulp/LanguageProficiencyEvaluator.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,106 +18,104 @@
import androidx.collection.ArrayMap;
import androidx.collection.ArraySet;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
import java.util.Set;
final class LanguageProficiencyEvaluator {
- private final SystemLanguagesProvider mSystemLanguagesProvider;
+ private final SystemLanguagesProvider systemLanguagesProvider;
- LanguageProficiencyEvaluator(SystemLanguagesProvider systemLanguagesProvider) {
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ LanguageProficiencyEvaluator(SystemLanguagesProvider systemLanguagesProvider) {
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ }
+
+ EvaluationResult evaluate(LanguageProficiencyAnalyzer analyzer, Set<String> languagesToEvaluate) {
+ Set<String> systemLanguageTags =
+ new ArraySet<>(systemLanguagesProvider.getSystemLanguageTags());
+ ArrayMap<String, Boolean> actual = new ArrayMap<>();
+ // We assume user can only speak the languages that are set as system languages.
+ for (String languageToEvaluate : languagesToEvaluate) {
+ actual.put(languageToEvaluate, systemLanguageTags.contains(languageToEvaluate));
+ }
+ return evaluateWithActual(analyzer, actual);
+ }
+
+ private static EvaluationResult evaluateWithActual(
+ LanguageProficiencyAnalyzer analyzer, ArrayMap<String, Boolean> actual) {
+ ArrayMap<String, Boolean> predict = new ArrayMap<>();
+ for (int i = 0; i < actual.size(); i++) {
+ String languageTag = actual.keyAt(i);
+ predict.put(languageTag, analyzer.canUnderstand(languageTag) >= 0.5f);
+ }
+ return EvaluationResult.create(actual, predict);
+ }
+
+ static final class EvaluationResult {
+ final int truePositive;
+ final int trueNegative;
+ final int falsePositive;
+ final int falseNegative;
+
+ private EvaluationResult(
+ int truePositive, int trueNegative, int falsePositive, int falseNegative) {
+ this.truePositive = truePositive;
+ this.trueNegative = trueNegative;
+ this.falsePositive = falsePositive;
+ this.falseNegative = falseNegative;
}
- EvaluationResult evaluate(
- LanguageProficiencyAnalyzer analyzer, Set<String> languagesToEvaluate) {
- Set<String> systemLanguageTags =
- new ArraySet<>(mSystemLanguagesProvider.getSystemLanguageTags());
- ArrayMap<String, Boolean> actual = new ArrayMap<>();
- // We assume user can only speak the languages that are set as system languages.
- for (String languageToEvaluate : languagesToEvaluate) {
- actual.put(languageToEvaluate, systemLanguageTags.contains(languageToEvaluate));
- }
- return evaluateWithActual(analyzer, actual);
+ float computePrecisionOfPositiveClass() {
+ float divisor = truePositive + falsePositive;
+ return divisor != 0 ? truePositive / divisor : 1f;
}
- private EvaluationResult evaluateWithActual(
- LanguageProficiencyAnalyzer analyzer, ArrayMap<String, Boolean> actual) {
- ArrayMap<String, Boolean> predict = new ArrayMap<>();
- for (int i = 0; i < actual.size(); i++) {
- String languageTag = actual.keyAt(i);
- predict.put(languageTag, analyzer.canUnderstand(languageTag) >= 0.5f);
- }
- return EvaluationResult.create(actual, predict);
+ float computePrecisionOfNegativeClass() {
+ float divisor = trueNegative + falseNegative;
+ return divisor != 0 ? trueNegative / divisor : 1f;
}
- static final class EvaluationResult {
- final int truePositive;
- final int trueNegative;
- final int falsePositive;
- final int falseNegative;
-
- private EvaluationResult(
- int truePositive, int trueNegative, int falsePositive, int falseNegative) {
- this.truePositive = truePositive;
- this.trueNegative = trueNegative;
- this.falsePositive = falsePositive;
- this.falseNegative = falseNegative;
- }
-
- float computePrecisionOfPositiveClass() {
- float divisor = truePositive + falsePositive;
- return divisor != 0 ? truePositive / divisor : 1f;
- }
-
- float computePrecisionOfNegativeClass() {
- float divisor = trueNegative + falseNegative;
- return divisor != 0 ? trueNegative / divisor : 1f;
- }
-
- float computeRecallOfPositiveClass() {
- float divisor = truePositive + falseNegative;
- return divisor != 0 ? truePositive / divisor : 1f;
- }
-
- float computeRecallOfNegativeClass() {
- float divisor = trueNegative + falsePositive;
- return divisor != 0 ? trueNegative / divisor : 1f;
- }
-
- float computeF1ScoreOfPositiveClass() {
- return 2 * truePositive / (float) (2 * truePositive + falsePositive + falseNegative);
- }
-
- float computeF1ScoreOfNegativeClass() {
- return 2 * trueNegative / (float) (2 * trueNegative + falsePositive + falseNegative);
- }
-
- static EvaluationResult create(
- ArrayMap<String, Boolean> actual, ArrayMap<String, Boolean> predict) {
- int truePositive = 0;
- int trueNegative = 0;
- int falsePositive = 0;
- int falseNegative = 0;
- for (int i = 0; i < actual.size(); i++) {
- String languageTag = actual.keyAt(i);
- boolean actualLabel = actual.valueAt(i);
- boolean predictLabel = predict.get(languageTag);
- if (predictLabel) {
- if (actualLabel == predictLabel) {
- truePositive += 1;
- } else {
- falsePositive += 1;
- }
- } else {
- if (actualLabel == predictLabel) {
- trueNegative += 1;
- } else {
- falseNegative += 1;
- }
- }
- }
- return new EvaluationResult(truePositive, trueNegative, falsePositive, falseNegative);
- }
+ float computeRecallOfPositiveClass() {
+ float divisor = truePositive + falseNegative;
+ return divisor != 0 ? truePositive / divisor : 1f;
}
+
+ float computeRecallOfNegativeClass() {
+ float divisor = trueNegative + falsePositive;
+ return divisor != 0 ? trueNegative / divisor : 1f;
+ }
+
+ float computeF1ScoreOfPositiveClass() {
+ return 2 * truePositive / (float) (2 * truePositive + falsePositive + falseNegative);
+ }
+
+ float computeF1ScoreOfNegativeClass() {
+ return 2 * trueNegative / (float) (2 * trueNegative + falsePositive + falseNegative);
+ }
+
+ static EvaluationResult create(
+ ArrayMap<String, Boolean> actual, ArrayMap<String, Boolean> predict) {
+ int truePositive = 0;
+ int trueNegative = 0;
+ int falsePositive = 0;
+ int falseNegative = 0;
+ for (int i = 0; i < actual.size(); i++) {
+ String languageTag = actual.keyAt(i);
+ boolean actualLabel = actual.valueAt(i);
+ boolean predictLabel = predict.get(languageTag);
+ if (predictLabel) {
+ if (actualLabel == predictLabel) {
+ truePositive += 1;
+ } else {
+ falsePositive += 1;
+ }
+ } else {
+ if (actualLabel == predictLabel) {
+ trueNegative += 1;
+ } else {
+ falseNegative += 1;
+ }
+ }
+ }
+ return new EvaluationResult(truePositive, trueNegative, falsePositive, falseNegative);
+ }
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/LanguageProfileAnalyzer.java b/java/src/com/android/textclassifier/ulp/LanguageProfileAnalyzer.java
index ac220f3..4c436dc 100644
--- a/java/src/com/android/textclassifier/ulp/LanguageProfileAnalyzer.java
+++ b/java/src/com/android/textclassifier/ulp/LanguageProfileAnalyzer.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,21 +19,17 @@
import android.content.Context;
import android.util.ArrayMap;
import android.view.textclassifier.TextClassifierEvent;
-
import androidx.annotation.FloatRange;
-import androidx.annotation.NonNull;
-
import com.android.textclassifier.Entity;
import com.android.textclassifier.TcLog;
import com.android.textclassifier.TextClassificationConstants;
import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
import com.android.textclassifier.ulp.database.LanguageSignalInfo;
import com.android.textclassifier.utils.IndentingPrintWriter;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.MoreExecutors;
-
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
@@ -51,188 +47,177 @@
* blocking operations and should be called on the worker thread.
*/
public class LanguageProfileAnalyzer {
- private final Context mContext;
- private final TextClassificationConstants mTextClassificationConstants;
- private final LanguageProfileDatabase mLanguageProfileDatabase;
- private final LanguageProficiencyAnalyzer mProficiencyAnalyzer;
- private final LocationSignalProvider mLocationSignalProvider;
- private final SystemLanguagesProvider mSystemLanguagesProvider;
+ private final Context context;
+ private final TextClassificationConstants textClassificationConstants;
+ private final LanguageProfileDatabase languageProfileDatabase;
+ private final LanguageProficiencyAnalyzer proficiencyAnalyzer;
+ private final LocationSignalProvider locationSignalProvider;
+ private final SystemLanguagesProvider systemLanguagesProvider;
- @VisibleForTesting
- LanguageProfileAnalyzer(
- Context context,
- TextClassificationConstants textClassificationConstants,
- LanguageProfileDatabase database,
- LanguageProficiencyAnalyzer languageProficiencyAnalyzer,
- LocationSignalProvider locationSignalProvider,
- SystemLanguagesProvider systemLanguagesProvider) {
- mContext = context;
- mTextClassificationConstants = textClassificationConstants;
- mLanguageProfileDatabase = Preconditions.checkNotNull(database);
- mProficiencyAnalyzer = Preconditions.checkNotNull(languageProficiencyAnalyzer);
- mLocationSignalProvider = Preconditions.checkNotNull(locationSignalProvider);
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ @VisibleForTesting
+ LanguageProfileAnalyzer(
+ Context context,
+ TextClassificationConstants textClassificationConstants,
+ LanguageProfileDatabase database,
+ LanguageProficiencyAnalyzer languageProficiencyAnalyzer,
+ LocationSignalProvider locationSignalProvider,
+ SystemLanguagesProvider systemLanguagesProvider) {
+ this.context = context;
+ this.textClassificationConstants = textClassificationConstants;
+ languageProfileDatabase = Preconditions.checkNotNull(database);
+ proficiencyAnalyzer = Preconditions.checkNotNull(languageProficiencyAnalyzer);
+ this.locationSignalProvider = Preconditions.checkNotNull(locationSignalProvider);
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ }
+
+ /** Creates an instance of {@link LanguageProfileAnalyzer}. */
+ public static LanguageProfileAnalyzer create(
+ Context context, TextClassificationConstants textClassificationConstants) {
+ SystemLanguagesProvider systemLanguagesProvider = new SystemLanguagesProvider();
+ LocationSignalProvider locationSignalProvider = new LocationSignalProvider(context);
+ return new LanguageProfileAnalyzer(
+ context,
+ textClassificationConstants,
+ LanguageProfileDatabase.getInstance(context),
+ new ReinforcementLanguageProficiencyAnalyzer(context, systemLanguagesProvider),
+ locationSignalProvider,
+ systemLanguagesProvider);
+ }
+
+ /**
+ * Returns the confidence score for which the user understands the given language. The result is
+ * recalculated every constant time.
+ *
+ * <p>The score ranges from 0 to 1. 1 indicates the language is very familiar to the user and vice
+ * versa.
+ */
+ @FloatRange(from = 0.0, to = 1.0)
+ public float canUnderstand(String languageTag) {
+ return proficiencyAnalyzer.canUnderstand(languageTag);
+ }
+
+ /** Decides whether we should show translation for that language or no. */
+ public boolean shouldShowTranslation(String languageTag) {
+ return proficiencyAnalyzer.shouldShowTranslation(languageTag);
+ }
+
+ /** Performs actions defined for specific TextClassification events. */
+ public void onTextClassifierEven(TextClassifierEvent event) {
+ proficiencyAnalyzer.onTextClassifierEvent(event);
+ }
+
+ /**
+ * Returns a list of languages that appear in the specified source, the list is sorted by the
+ * frequency descendingly. The confidence score represents how frequent of the language is,
+ * compared to the most frequent language.
+ */
+ public List<Entity> getFrequentLanguages(@LanguageSignalInfo.Source int source) {
+ List<LanguageSignalInfo> languageSignalInfos =
+ languageProfileDatabase.languageInfoDao().getBySource(source);
+ int bootstrappingCount = textClassificationConstants.getFrequentLanguagesBootstrappingCount();
+ ArrayMap<String, Integer> languageCountMap = new ArrayMap<>();
+ systemLanguagesProvider
+ .getSystemLanguageTags()
+ .forEach(lang -> languageCountMap.put(lang, bootstrappingCount));
+ String languageTagFromLocation = locationSignalProvider.detectLanguageTag();
+ if (languageTagFromLocation != null) {
+ languageCountMap.put(
+ languageTagFromLocation,
+ languageCountMap.getOrDefault(languageTagFromLocation, 0) + bootstrappingCount);
}
-
- /** Creates an instance of {@link LanguageProfileAnalyzer}. */
- public static LanguageProfileAnalyzer create(
- Context context, TextClassificationConstants textClassificationConstants) {
- SystemLanguagesProvider systemLanguagesProvider = new SystemLanguagesProvider();
- LocationSignalProvider locationSignalProvider = new LocationSignalProvider(context);
- return new LanguageProfileAnalyzer(
- context,
- textClassificationConstants,
- LanguageProfileDatabase.getInstance(context),
- new ReinforcementLanguageProficiencyAnalyzer(context, systemLanguagesProvider),
- locationSignalProvider,
- systemLanguagesProvider);
+ for (LanguageSignalInfo languageSignalInfo : languageSignalInfos) {
+ String lang = languageSignalInfo.getLanguageTag();
+ languageCountMap.put(
+ lang, languageSignalInfo.getCount() + languageCountMap.getOrDefault(lang, 0));
}
-
- /**
- * Returns the confidence score for which the user understands the given language. The result is
- * recalculated every constant time.
- *
- * <p>The score ranges from 0 to 1. 1 indicates the language is very familiar to the user and
- * vice versa.
- */
- @FloatRange(from = 0.0, to = 1.0)
- public float canUnderstand(String languageTag) {
- return mProficiencyAnalyzer.canUnderstand(languageTag);
+ int max = Collections.max(languageCountMap.values());
+ if (max == 0) {
+ return ImmutableList.of();
}
-
- /** Decides whether we should show translation for that language or no. */
- public boolean shouldShowTranslation(String languageTag) {
- return mProficiencyAnalyzer.shouldShowTranslation(languageTag);
+ List<Entity> frequentLanguages = new ArrayList<>();
+ for (int i = 0; i < languageCountMap.size(); i++) {
+ String lang = languageCountMap.keyAt(i);
+ float score = languageCountMap.valueAt(i) / (float) max;
+ frequentLanguages.add(new Entity(lang, score));
}
+ Collections.sort(frequentLanguages);
+ return ImmutableList.copyOf(frequentLanguages);
+ }
- /** Performs actions defined for specific TextClassification events. */
- public void onTextClassifierEven(TextClassifierEvent event) {
- mProficiencyAnalyzer.onTextClassifierEvent(event);
+ /** Dumps the data on the screen when called. */
+ public void dump(IndentingPrintWriter printWriter) {
+ printWriter.println("LanguageProfileAnalyzer:");
+ printWriter.increaseIndent();
+ printWriter.printPair(
+ "System languages", String.join(",", systemLanguagesProvider.getSystemLanguageTags()));
+ printWriter.printPair(
+ "Language code deduced from location", locationSignalProvider.detectLanguageTag());
+
+ ExecutorService executorService =
+ MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor());
+ try {
+ executorService
+ .submit(
+ () -> {
+ printWriter.println("Languages that user has seen in selections:");
+ dumpFrequentLanguages(printWriter, LanguageSignalInfo.CLASSIFY_TEXT);
+
+ printWriter.println("Languages that user has seen in message notifications:");
+ dumpFrequentLanguages(printWriter, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
+
+ dumpEvaluationReport(printWriter);
+ })
+ .get();
+ } catch (ExecutionException | InterruptedException e) {
+ TcLog.e(TcLog.TAG, "Dumping interrupted: ", e);
}
+ printWriter.decreaseIndent();
+ }
- /**
- * Returns a list of languages that appear in the specified source, the list is sorted by the
- * frequency descendingly. The confidence score represents how frequent of the language is,
- * compared to the most frequent language.
- */
- @NonNull
- public List<Entity> getFrequentLanguages(@LanguageSignalInfo.Source int source) {
- List<LanguageSignalInfo> languageSignalInfos =
- mLanguageProfileDatabase.languageInfoDao().getBySource(source);
- int bootstrappingCount =
- mTextClassificationConstants.getFrequentLanguagesBootstrappingCount();
- ArrayMap<String, Integer> languageCountMap = new ArrayMap<>();
- mSystemLanguagesProvider
- .getSystemLanguageTags()
- .forEach(lang -> languageCountMap.put(lang, bootstrappingCount));
- String languageTagFromLocation = mLocationSignalProvider.detectLanguageTag();
- if (languageTagFromLocation != null) {
- languageCountMap.put(
- languageTagFromLocation,
- languageCountMap.getOrDefault(languageTagFromLocation, 0) + bootstrappingCount);
- }
- for (LanguageSignalInfo languageSignalInfo : languageSignalInfos) {
- String lang = languageSignalInfo.getLanguageTag();
- languageCountMap.put(
- lang, languageSignalInfo.getCount() + languageCountMap.getOrDefault(lang, 0));
- }
- int max = Collections.max(languageCountMap.values());
- if (max == 0) {
- return Collections.emptyList();
- }
- List<Entity> frequentLanguages = new ArrayList<>();
- for (int i = 0; i < languageCountMap.size(); i++) {
- String lang = languageCountMap.keyAt(i);
- float score = languageCountMap.valueAt(i) / (float) max;
- frequentLanguages.add(new Entity(lang, score));
- }
- Collections.sort(frequentLanguages);
- return frequentLanguages;
+ private void dumpEvaluationReport(IndentingPrintWriter printWriter) {
+ List<String> systemLanguageTags = systemLanguagesProvider.getSystemLanguageTags();
+ if (systemLanguageTags.size() <= 1) {
+ printWriter.println("Skipped evaluation as there are less than two system languages.");
+ return;
}
-
- /** Dumps the data on the screen when called. */
- public void dump(IndentingPrintWriter printWriter) {
- printWriter.println("LanguageProfileAnalyzer:");
- printWriter.increaseIndent();
- printWriter.printPair(
- "System languages",
- String.join(",", mSystemLanguagesProvider.getSystemLanguageTags()));
- printWriter.printPair(
- "Language code deduced from location", mLocationSignalProvider.detectLanguageTag());
-
- ExecutorService executorService =
- MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor());
- try {
- executorService
- .submit(
- () -> {
- printWriter.println("Languages that user has seen in selections:");
- dumpFrequentLanguages(
- printWriter, LanguageSignalInfo.CLASSIFY_TEXT);
-
- printWriter.println(
- "Languages that user has seen in message notifications:");
- dumpFrequentLanguages(
- printWriter,
- LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS);
-
- dumpEvaluationReport(printWriter);
- })
- .get();
- } catch (ExecutionException | InterruptedException e) {
- TcLog.e(TcLog.TAG, "Dumping interrupted: ", e);
- }
- printWriter.decreaseIndent();
+ Set<String> languagesToEvaluate =
+ languageProfileDatabase.languageInfoDao().getAll().stream()
+ .map(LanguageSignalInfo::getLanguageTag)
+ .collect(Collectors.toSet());
+ languagesToEvaluate.addAll(systemLanguageTags);
+ LanguageProficiencyEvaluator evaluator =
+ new LanguageProficiencyEvaluator(systemLanguagesProvider);
+ LanguageProficiencyAnalyzer[] analyzers =
+ new LanguageProficiencyAnalyzer[] {
+ new BasicLanguageProficiencyAnalyzer(
+ context, textClassificationConstants, systemLanguagesProvider),
+ new KmeansLanguageProficiencyAnalyzer(
+ context, textClassificationConstants, systemLanguagesProvider),
+ proficiencyAnalyzer
+ };
+ for (LanguageProficiencyAnalyzer analyzer : analyzers) {
+ LanguageProficiencyEvaluator.EvaluationResult result =
+ evaluator.evaluate(analyzer, languagesToEvaluate);
+ printWriter.println("Evaluation result of " + analyzer.getClass().getSimpleName());
+ printWriter.increaseIndent();
+ printWriter.printPair(
+ "Precision of positive class", result.computePrecisionOfPositiveClass());
+ printWriter.printPair(
+ "Precision of negative class", result.computePrecisionOfNegativeClass());
+ printWriter.printPair("Recall of positive class", result.computeRecallOfPositiveClass());
+ printWriter.printPair("Recall of negative class", result.computeRecallOfNegativeClass());
+ printWriter.printPair("F1 score of positive class", result.computeF1ScoreOfPositiveClass());
+ printWriter.printPair("F1 score of negative class", result.computeF1ScoreOfNegativeClass());
+ printWriter.decreaseIndent();
}
+ }
- private void dumpEvaluationReport(IndentingPrintWriter printWriter) {
- List<String> systemLanguageTags = mSystemLanguagesProvider.getSystemLanguageTags();
- if (systemLanguageTags.size() <= 1) {
- printWriter.println("Skipped evaluation as there are less than two system languages.");
- return;
- }
- Set<String> languagesToEvaluate =
- mLanguageProfileDatabase.languageInfoDao().getAll().stream()
- .map(LanguageSignalInfo::getLanguageTag)
- .collect(Collectors.toSet());
- languagesToEvaluate.addAll(systemLanguageTags);
- LanguageProficiencyEvaluator evaluator =
- new LanguageProficiencyEvaluator(mSystemLanguagesProvider);
- LanguageProficiencyAnalyzer[] analyzers =
- new LanguageProficiencyAnalyzer[] {
- new BasicLanguageProficiencyAnalyzer(
- mContext, mTextClassificationConstants, mSystemLanguagesProvider),
- new KmeansLanguageProficiencyAnalyzer(
- mContext, mTextClassificationConstants, mSystemLanguagesProvider),
- mProficiencyAnalyzer
- };
- for (LanguageProficiencyAnalyzer analyzer : analyzers) {
- LanguageProficiencyEvaluator.EvaluationResult result =
- evaluator.evaluate(analyzer, languagesToEvaluate);
- printWriter.println("Evaluation result of " + analyzer.getClass().getSimpleName());
- printWriter.increaseIndent();
- printWriter.printPair(
- "Precision of positive class", result.computePrecisionOfPositiveClass());
- printWriter.printPair(
- "Precision of negative class", result.computePrecisionOfNegativeClass());
- printWriter.printPair(
- "Recall of positive class", result.computeRecallOfPositiveClass());
- printWriter.printPair(
- "Recall of negative class", result.computeRecallOfNegativeClass());
- printWriter.printPair(
- "F1 score of positive class", result.computeF1ScoreOfPositiveClass());
- printWriter.printPair(
- "F1 score of negative class", result.computeF1ScoreOfNegativeClass());
- printWriter.decreaseIndent();
- }
+ private void dumpFrequentLanguages(
+ IndentingPrintWriter printWriter, @LanguageSignalInfo.Source int source) {
+ printWriter.increaseIndent();
+ for (Entity frequentLanguage : getFrequentLanguages(source)) {
+ printWriter.println(frequentLanguage.toString());
}
-
- private void dumpFrequentLanguages(
- IndentingPrintWriter printWriter, @LanguageSignalInfo.Source int source) {
- printWriter.increaseIndent();
- for (Entity frequentLanguage : getFrequentLanguages(source)) {
- printWriter.println(frequentLanguage.toString());
- }
- printWriter.decreaseIndent();
- }
+ printWriter.decreaseIndent();
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/LanguageProfileUpdater.java b/java/src/com/android/textclassifier/ulp/LanguageProfileUpdater.java
index 55ba729..56f878f 100644
--- a/java/src/com/android/textclassifier/ulp/LanguageProfileUpdater.java
+++ b/java/src/com/android/textclassifier/ulp/LanguageProfileUpdater.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,158 +19,154 @@
import android.content.Context;
import android.util.LruCache;
import android.view.textclassifier.ConversationActions;
-
-import androidx.annotation.VisibleForTesting;
-
import com.android.textclassifier.TcLog;
import com.android.textclassifier.ulp.database.LanguageProfileDatabase;
import com.android.textclassifier.ulp.database.LanguageSignalInfo;
import com.android.textclassifier.utils.IndentingPrintWriter;
-
+import com.google.common.annotations.VisibleForTesting;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
-
+import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.time.Instant;
-import java.time.OffsetDateTime;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.function.Function;
-
import javax.annotation.Nullable;
/** Class implementing functions which builds and updates user language profile. */
public class LanguageProfileUpdater {
- private static final String TAG = "LanguageProfileUpdater";
- private static final int MAX_CACHE_SIZE = 20;
- private static final String DEFAULT_NOTIFICATION_KEY = "DEFAULT_KEY";
+ private static final String TAG = "LanguageProfileUpdater";
+ private static final int MAX_CACHE_SIZE = 20;
+ private static final String DEFAULT_NOTIFICATION_KEY = "DEFAULT_KEY";
- static final String NOTIFICATION_KEY = "notificationKey";
+ static final String NOTIFICATION_KEY = "notificationKey";
- private final LanguageProfileDatabase mLanguageProfileDatabase;
- private final ListeningExecutorService mExecutorService;
- private final LruCache<String, Long> mUpdatedNotifications = new LruCache<>(MAX_CACHE_SIZE);
+ private final LanguageProfileDatabase languageProfileDatabase;
+ private final ListeningExecutorService executorService;
+ private final LruCache<String, Long> updatedNotifications = new LruCache<>(MAX_CACHE_SIZE);
- public LanguageProfileUpdater(Context context, ListeningExecutorService executorService) {
- mLanguageProfileDatabase = LanguageProfileDatabase.getInstance(context);
- mExecutorService = executorService;
- }
+ public LanguageProfileUpdater(Context context, ListeningExecutorService executorService) {
+ languageProfileDatabase = LanguageProfileDatabase.getInstance(context);
+ this.executorService = executorService;
+ }
- @VisibleForTesting
- LanguageProfileUpdater(
- ListeningExecutorService executorService, LanguageProfileDatabase database) {
- mLanguageProfileDatabase = database;
- mExecutorService = executorService;
- }
+ @VisibleForTesting
+ LanguageProfileUpdater(
+ ListeningExecutorService executorService, LanguageProfileDatabase database) {
+ languageProfileDatabase = database;
+ this.executorService = executorService;
+ }
- /** Updates counts of languages found in suggestConversationActions. */
- public ListenableFuture<Void> updateFromConversationActionsAsync(
- ConversationActions.Request request,
- Function<CharSequence, List<String>> languageDetector) {
- return runAsync(
- () -> {
- ConversationActions.Message msg = getMessageFromRequest(request);
- if (msg == null) {
- return null;
- }
- List<String> languageTags = languageDetector.apply(msg.getText().toString());
- String notificationKey =
- request.getExtras()
- .getString(NOTIFICATION_KEY, DEFAULT_NOTIFICATION_KEY);
- Long messageReferenceTime = getMessageReferenceTime(msg);
- if (isNewMessage(notificationKey, messageReferenceTime)) {
- for (String tag : languageTags) {
- increaseSignalCountInDatabase(
- tag, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 1);
- }
- }
- return null;
- });
- }
-
- /** Updates counts of languages found in classifyText. */
- public ListenableFuture<Void> updateFromClassifyTextAsync(List<String> detectedLanguageTags) {
- return runAsync(
- () -> {
- for (String languageTag : detectedLanguageTags) {
- increaseSignalCountInDatabase(
- languageTag, LanguageSignalInfo.CLASSIFY_TEXT, /* increment= */ 1);
- }
- return null;
- });
- }
-
- /** Runs the specified callable asynchronously and prints the stack trace if it failed. */
- private <T> ListenableFuture<T> runAsync(Callable<T> callable) {
- ListenableFuture<T> future = mExecutorService.submit(callable);
- Futures.addCallback(
- future,
- new FutureCallback<T>() {
- @Override
- public void onSuccess(T result) {}
-
- @Override
- public void onFailure(Throwable t) {
- TcLog.e(TAG, "runAsync", t);
- }
- },
- MoreExecutors.directExecutor());
- return future;
- }
-
- private void increaseSignalCountInDatabase(
- String languageTag, @LanguageSignalInfo.Source int sourceType, int increment) {
- mLanguageProfileDatabase
- .languageInfoDao()
- .increaseSignalCount(languageTag, sourceType, increment);
- }
-
- @Nullable
- private ConversationActions.Message getMessageFromRequest(ConversationActions.Request request) {
- int size = request.getConversation().size();
- if (size == 0) {
+ /** Updates counts of languages found in suggestConversationActions. */
+ @CanIgnoreReturnValue
+ public ListenableFuture<Void> updateFromConversationActionsAsync(
+ ConversationActions.Request request, Function<CharSequence, List<String>> languageDetector) {
+ return runAsync(
+ () -> {
+ ConversationActions.Message msg = getMessageFromRequest(request);
+ if (msg == null) {
return null;
- }
- return request.getConversation().get(size - 1);
- }
+ }
+ List<String> languageTags = languageDetector.apply(msg.getText().toString());
+ String notificationKey =
+ request.getExtras().getString(NOTIFICATION_KEY, DEFAULT_NOTIFICATION_KEY);
+ Long messageReferenceTime = getMessageReferenceTime(msg);
+ if (isNewMessage(notificationKey, messageReferenceTime)) {
+ for (String tag : languageTags) {
+ increaseSignalCountInDatabase(
+ tag, LanguageSignalInfo.SUGGEST_CONVERSATION_ACTIONS, 1);
+ }
+ }
+ return null;
+ });
+ }
- private boolean isNewMessage(String notificationKey, Long sendTime) {
- Long oldTime = mUpdatedNotifications.get(notificationKey);
+ /** Updates counts of languages found in classifyText. */
+ @CanIgnoreReturnValue
+ public ListenableFuture<Void> updateFromClassifyTextAsync(List<String> detectedLanguageTags) {
+ return runAsync(
+ () -> {
+ for (String languageTag : detectedLanguageTags) {
+ increaseSignalCountInDatabase(
+ languageTag, LanguageSignalInfo.CLASSIFY_TEXT, /* increment= */ 1);
+ }
+ return null;
+ });
+ }
- if (oldTime == null || sendTime > oldTime) {
- mUpdatedNotifications.put(notificationKey, sendTime);
- return true;
- }
- return false;
- }
+ /** Runs the specified callable asynchronously and prints the stack trace if it failed. */
+ private <T> ListenableFuture<T> runAsync(Callable<T> callable) {
+ ListenableFuture<T> future = executorService.submit(callable);
+ Futures.addCallback(
+ future,
+ new FutureCallback<T>() {
+ @Override
+ public void onSuccess(T result) {}
- private long getMessageReferenceTime(ConversationActions.Message msg) {
- return msg.getReferenceTime() == null
- ? OffsetDateTime.now().toInstant().toEpochMilli()
- : msg.getReferenceTime().toInstant().toEpochMilli();
- }
+ @Override
+ public void onFailure(Throwable t) {
+ TcLog.e(TAG, "runAsync", t);
+ }
+ },
+ MoreExecutors.directExecutor());
+ return future;
+ }
- /** Dumps the data on the screen when called. */
- public void dump(IndentingPrintWriter printWriter) {
- printWriter.println("LanguageProfileUpdater:");
- printWriter.increaseIndent();
- printWriter.println("Cache for notifications status:");
- printWriter.increaseIndent();
- for (Map.Entry<String, Long> entry : mUpdatedNotifications.snapshot().entrySet()) {
- long timestamp = entry.getValue();
- printWriter.println(
- "Notification key: "
- + entry.getKey()
- + " time: "
- + timestamp
- + " ("
- + Instant.ofEpochMilli(timestamp).toString()
- + ")");
- }
- printWriter.decreaseIndent();
- printWriter.decreaseIndent();
+ private void increaseSignalCountInDatabase(
+ String languageTag, @LanguageSignalInfo.Source int sourceType, int increment) {
+ languageProfileDatabase
+ .languageInfoDao()
+ .increaseSignalCount(languageTag, sourceType, increment);
+ }
+
+ @Nullable
+ private static ConversationActions.Message getMessageFromRequest(
+ ConversationActions.Request request) {
+ int size = request.getConversation().size();
+ if (size == 0) {
+ return null;
}
+ return request.getConversation().get(size - 1);
+ }
+
+ private boolean isNewMessage(String notificationKey, Long sendTime) {
+ Long oldTime = updatedNotifications.get(notificationKey);
+
+ if (oldTime == null || sendTime > oldTime) {
+ updatedNotifications.put(notificationKey, sendTime);
+ return true;
+ }
+ return false;
+ }
+
+ private static long getMessageReferenceTime(ConversationActions.Message msg) {
+ return msg.getReferenceTime() == null
+ ? Instant.now().toEpochMilli()
+ : msg.getReferenceTime().toInstant().toEpochMilli();
+ }
+
+ /** Dumps the data on the screen when called. */
+ public void dump(IndentingPrintWriter printWriter) {
+ printWriter.println("LanguageProfileUpdater:");
+ printWriter.increaseIndent();
+ printWriter.println("Cache for notifications status:");
+ printWriter.increaseIndent();
+ for (Map.Entry<String, Long> entry : updatedNotifications.snapshot().entrySet()) {
+ long timestamp = entry.getValue();
+ printWriter.println(
+ "Notification key: "
+ + entry.getKey()
+ + " time: "
+ + timestamp
+ + " ("
+ + Instant.ofEpochMilli(timestamp).toString()
+ + ")");
+ }
+ printWriter.decreaseIndent();
+ printWriter.decreaseIndent();
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/LocationSignalProvider.java b/java/src/com/android/textclassifier/ulp/LocationSignalProvider.java
index c3ad77a..8b678b9 100644
--- a/java/src/com/android/textclassifier/ulp/LocationSignalProvider.java
+++ b/java/src/com/android/textclassifier/ulp/LocationSignalProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -23,103 +23,99 @@
import android.location.Location;
import android.location.LocationManager;
import android.telephony.TelephonyManager;
-
-import androidx.annotation.Nullable;
-
import com.android.textclassifier.TcLog;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-
import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
final class LocationSignalProvider {
- private static final String TAG = "LocationSignalProvider";
+ private static final String TAG = "LocationSignalProvider";
- private static final long EXPIRATION_TIME = TimeUnit.DAYS.toMillis(1);
- private final LocationManager mLocationManager;
- private final TelephonyManager mTelephonyManager;
- private final Geocoder mGeocoder;
- private final Object mLock = new Object();
+ private static final long EXPIRATION_TIME = TimeUnit.DAYS.toMillis(1);
+ private final LocationManager locationManager;
+ private final TelephonyManager telephonyManager;
+ private final Geocoder geocoder;
+ private final Object lock = new Object();
- private String mCachedLanguageTag = null;
- private long mLastModifiedTime;
+ private String cachedLanguageTag = null;
+ private long lastModifiedTime;
- LocationSignalProvider(Context context) {
- this(
- context.getSystemService(LocationManager.class),
- context.getSystemService(TelephonyManager.class),
- new Geocoder(context));
+ LocationSignalProvider(Context context) {
+ this(
+ context.getSystemService(LocationManager.class),
+ context.getSystemService(TelephonyManager.class),
+ new Geocoder(context));
+ }
+
+ @VisibleForTesting
+ LocationSignalProvider(
+ LocationManager locationManager, TelephonyManager telephonyManager, Geocoder geocoder) {
+ this.locationManager = Preconditions.checkNotNull(locationManager);
+ this.telephonyManager = Preconditions.checkNotNull(telephonyManager);
+ this.geocoder = Preconditions.checkNotNull(geocoder);
+ }
+
+ /**
+ * Deduces the language by using user's current location as a signal and returns a BCP 47 language
+ * code.
+ */
+ @Nullable
+ String detectLanguageTag() {
+ synchronized (lock) {
+ if ((System.currentTimeMillis() - lastModifiedTime) < EXPIRATION_TIME) {
+ return cachedLanguageTag;
+ }
+ cachedLanguageTag = detectLanguageCodeInternal();
+ lastModifiedTime = System.currentTimeMillis();
+ return cachedLanguageTag;
}
+ }
- @VisibleForTesting
- LocationSignalProvider(
- LocationManager locationManager, TelephonyManager telephonyManager, Geocoder geocoder) {
- mLocationManager = Preconditions.checkNotNull(locationManager);
- mTelephonyManager = Preconditions.checkNotNull(telephonyManager);
- mGeocoder = Preconditions.checkNotNull(geocoder);
+ @Nullable
+ private String detectLanguageCodeInternal() {
+ String currentCountryCode = detectCurrentCountryCode();
+ if (currentCountryCode == null) {
+ return null;
}
+ return toLanguageTag(currentCountryCode);
+ }
- /**
- * Deduces the language by using user's current location as a signal and returns a BCP 47
- * language code.
- */
- @Nullable
- String detectLanguageTag() {
- synchronized (mLock) {
- if ((System.currentTimeMillis() - mLastModifiedTime) < EXPIRATION_TIME) {
- return mCachedLanguageTag;
- }
- mCachedLanguageTag = detectLanguageCodeInternal();
- mLastModifiedTime = System.currentTimeMillis();
- return mCachedLanguageTag;
- }
+ @Nullable
+ private String detectCurrentCountryCode() {
+ String networkCountryCode = telephonyManager.getNetworkCountryIso();
+ if (networkCountryCode != null) {
+ return networkCountryCode;
}
+ return detectCurrentCountryFromLocationManager();
+ }
- @Nullable
- private String detectLanguageCodeInternal() {
- String currentCountryCode = detectCurrentCountryCode();
- if (currentCountryCode == null) {
- return null;
- }
- return toLanguageTag(currentCountryCode);
+ @Nullable
+ private String detectCurrentCountryFromLocationManager() {
+ Location location = locationManager.getLastKnownLocation(LocationManager.PASSIVE_PROVIDER);
+ if (location == null) {
+ return null;
}
+ try {
+ List<Address> addresses =
+ geocoder.getFromLocation(
+ location.getLatitude(), location.getLongitude(), /*maxResults=*/ 1);
+ if (addresses != null && !addresses.isEmpty()) {
+ return addresses.get(0).getCountryCode();
+ }
+ } catch (IOException e) {
+ TcLog.e(TAG, "Failed to call getFromLocation: ", e);
+ return null;
+ }
+ return null;
+ }
- @Nullable
- private String detectCurrentCountryCode() {
- String networkCountryCode = mTelephonyManager.getNetworkCountryIso();
- if (networkCountryCode != null) {
- return networkCountryCode;
- }
- return detectCurrentCountryFromLocationManager();
- }
-
- @Nullable
- private String detectCurrentCountryFromLocationManager() {
- Location location = mLocationManager.getLastKnownLocation(LocationManager.PASSIVE_PROVIDER);
- if (location == null) {
- return null;
- }
- try {
- List<Address> addresses =
- mGeocoder.getFromLocation(
- location.getLatitude(), location.getLongitude(), /*maxResults=*/ 1);
- if (addresses != null && !addresses.isEmpty()) {
- return addresses.get(0).getCountryCode();
- }
- } catch (IOException e) {
- TcLog.e(TAG, "Failed to call getFromLocation: ", e);
- return null;
- }
- return null;
- }
-
- @Nullable
- private static String toLanguageTag(String countryTag) {
- ULocale locale = new ULocale.Builder().setRegion(countryTag).build();
- locale = ULocale.addLikelySubtags(locale);
- return locale.getLanguage();
- }
+ @Nullable
+ private static String toLanguageTag(String countryTag) {
+ ULocale locale = new ULocale.Builder().setRegion(countryTag).build();
+ locale = ULocale.addLikelySubtags(locale);
+ return locale.getLanguage();
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzer.java b/java/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzer.java
index 851e81b..d23be60 100644
--- a/java/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzer.java
+++ b/java/src/com/android/textclassifier/ulp/ReinforcementLanguageProficiencyAnalyzer.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,12 +19,9 @@
import android.content.Context;
import android.content.SharedPreferences;
import android.view.textclassifier.TextClassifierEvent;
-
import com.android.textclassifier.TcLog;
-
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
-
import org.json.JSONException;
import org.json.JSONObject;
@@ -37,159 +34,159 @@
* translation action in the future.
*/
class ReinforcementLanguageProficiencyAnalyzer implements LanguageProficiencyAnalyzer {
- private static final String TAG = "ReinforcementAnalyzer";
- private static final String PREF_NAME = "ulp-reinforcement-analyzer";
- private static final float SHOW_TRANSLATE_ACTION_THRESHOLD = 0.9f;
- private static final int MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT = 30;
+ private static final String TAG = "ReinforcementAnalyzer";
+ private static final String PREF_NAME = "ulp-reinforcement-analyzer";
+ private static final float SHOW_TRANSLATE_ACTION_THRESHOLD = 0.9f;
+ private static final int MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT = 30;
- private final SystemLanguagesProvider mSystemLanguagesProvider;
- private final SharedPreferences mSharedPreferences;
+ private final SystemLanguagesProvider systemLanguagesProvider;
+ private final SharedPreferences sharedPreferences;
- ReinforcementLanguageProficiencyAnalyzer(
- Context context, SystemLanguagesProvider systemLanguagesProvider) {
- Preconditions.checkNotNull(context);
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
- mSharedPreferences = context.getSharedPreferences(PREF_NAME, Context.MODE_PRIVATE);
+ ReinforcementLanguageProficiencyAnalyzer(
+ Context context, SystemLanguagesProvider systemLanguagesProvider) {
+ Preconditions.checkNotNull(context);
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ sharedPreferences = context.getSharedPreferences(PREF_NAME, Context.MODE_PRIVATE);
+ }
+
+ @VisibleForTesting
+ ReinforcementLanguageProficiencyAnalyzer(
+ SystemLanguagesProvider systemLanguagesProvider, SharedPreferences sharedPreferences) {
+ this.systemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
+ this.sharedPreferences = Preconditions.checkNotNull(sharedPreferences);
+ }
+
+ @Override
+ public float canUnderstand(String languageTag) {
+ TranslationStatistics translationStatistics =
+ TranslationStatistics.loadFromSharedPreference(sharedPreferences, languageTag);
+ if (translationStatistics.getShownCount() < MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT) {
+ return systemLanguagesProvider.getSystemLanguageTags().contains(languageTag) ? 1f : 0f;
+ }
+ return translationStatistics.getScore();
+ }
+
+ @Override
+ public boolean shouldShowTranslation(String languageTag) {
+ TranslationStatistics translationStatistics =
+ TranslationStatistics.loadFromSharedPreference(sharedPreferences, languageTag);
+ if (translationStatistics.getShownCount() < MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT) {
+ // Show translate action until we have enough feedback.
+ return true;
+ }
+ return translationStatistics.getScore() <= SHOW_TRANSLATE_ACTION_THRESHOLD;
+ }
+
+ @Override
+ public void onTextClassifierEvent(TextClassifierEvent event) {
+ if (event.getEventCategory() == TextClassifierEvent.CATEGORY_LANGUAGE_DETECTION) {
+ if (event.getEventType() == TextClassifierEvent.TYPE_SMART_ACTION
+ || event.getEventType() == TextClassifierEvent.TYPE_ACTIONS_SHOWN) {
+ onTranslateEvent(event);
+ }
+ }
+ }
+
+ private void onTranslateEvent(TextClassifierEvent event) {
+ if (event.getEntityTypes().length == 0) {
+ return;
+ }
+ String languageTag = event.getEntityTypes()[0];
+ // We only count the case that we show translate action in the prime position.
+ if (event.getActionIndices().length == 0 || event.getActionIndices()[0] != 0) {
+ return;
+ }
+ TranslationStatistics translationStatistics =
+ TranslationStatistics.loadFromSharedPreference(sharedPreferences, languageTag);
+ if (event.getEventType() == TextClassifierEvent.TYPE_ACTIONS_SHOWN) {
+ translationStatistics.increaseShownCountByOne();
+ } else if (event.getEventType() == TextClassifierEvent.TYPE_SMART_ACTION) {
+ translationStatistics.increaseClickedCountByOne();
+ }
+ translationStatistics.save(sharedPreferences, languageTag);
+ }
+
+ private static final class TranslationStatistics {
+ private static final String SEEN_COUNT = "seen_count";
+ private static final String CLICK_COUNT = "click_count";
+
+ private int shownCount;
+ private int clickCount;
+
+ private TranslationStatistics() {
+ this(/* seenCount= */ 0, /* clickCount= */ 0);
}
- @VisibleForTesting
- ReinforcementLanguageProficiencyAnalyzer(
- SystemLanguagesProvider systemLanguagesProvider, SharedPreferences sharedPreferences) {
- mSystemLanguagesProvider = Preconditions.checkNotNull(systemLanguagesProvider);
- mSharedPreferences = Preconditions.checkNotNull(sharedPreferences);
+ private TranslationStatistics(int seenCount, int clickCount) {
+ shownCount = seenCount;
+ this.clickCount = clickCount;
+ }
+
+ static TranslationStatistics loadFromSharedPreference(
+ SharedPreferences sharedPreferences, String languageTag) {
+ String serializedString = sharedPreferences.getString(languageTag, null);
+ return TranslationStatistics.fromSerializedString(serializedString);
+ }
+
+ void save(SharedPreferences sharedPreferences, String languageTag) {
+ // TODO: Consider to store it in a database.
+ sharedPreferences.edit().putString(languageTag, serializeToString()).apply();
+ }
+
+ private String serializeToString() {
+ try {
+ JSONObject jsonObject = new JSONObject();
+ jsonObject.put(SEEN_COUNT, shownCount);
+ jsonObject.put(CLICK_COUNT, clickCount);
+ return jsonObject.toString();
+ } catch (JSONException ex) {
+ TcLog.e(TAG, "serializeToString: ", ex);
+ }
+ return "";
+ }
+
+ void increaseShownCountByOne() {
+ shownCount += 1;
+ }
+
+ void increaseClickedCountByOne() {
+ clickCount += 1;
+ }
+
+ float getScore() {
+ if (shownCount == 0) {
+ return 0f;
+ }
+ return clickCount / (float) shownCount;
+ }
+
+ int getShownCount() {
+ return shownCount;
+ }
+
+ static TranslationStatistics fromSerializedString(String str) {
+ if (str == null) {
+ return new TranslationStatistics();
+ }
+ try {
+ JSONObject jsonObject = new JSONObject(str);
+ int seenCount = jsonObject.getInt(SEEN_COUNT);
+ int clickCount = jsonObject.getInt(CLICK_COUNT);
+ return new TranslationStatistics(seenCount, clickCount);
+ } catch (JSONException ex) {
+ TcLog.e(TAG, "Failed to parse " + str, ex);
+ }
+ return new TranslationStatistics();
}
@Override
- public float canUnderstand(String languageTag) {
- TranslationStatistics translationStatistics =
- TranslationStatistics.loadFromSharedPreference(mSharedPreferences, languageTag);
- if (translationStatistics.getShownCount() < MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT) {
- return mSystemLanguagesProvider.getSystemLanguageTags().contains(languageTag) ? 1f : 0f;
- }
- return translationStatistics.getScore();
+ public String toString() {
+ return "TranslationStatistics{"
+ + "mShownCount="
+ + shownCount
+ + ", mClickCount="
+ + clickCount
+ + '}';
}
-
- @Override
- public boolean shouldShowTranslation(String languageTag) {
- TranslationStatistics translationStatistics =
- TranslationStatistics.loadFromSharedPreference(mSharedPreferences, languageTag);
- if (translationStatistics.getShownCount() < MIN_NUM_TRANSLATE_SHOWN_TO_BE_CONFIDENT) {
- // Show translate action until we have enough feedback.
- return true;
- }
- return translationStatistics.getScore() <= SHOW_TRANSLATE_ACTION_THRESHOLD;
- }
-
- @Override
- public void onTextClassifierEvent(TextClassifierEvent event) {
- if (event.getEventCategory() == TextClassifierEvent.CATEGORY_LANGUAGE_DETECTION) {
- if (event.getEventType() == TextClassifierEvent.TYPE_SMART_ACTION
- || event.getEventType() == TextClassifierEvent.TYPE_ACTIONS_SHOWN) {
- onTranslateEvent(event);
- }
- }
- }
-
- private void onTranslateEvent(TextClassifierEvent event) {
- if (event.getEntityTypes().length == 0) {
- return;
- }
- String languageTag = event.getEntityTypes()[0];
- // We only count the case that we show translate action in the prime position.
- if (event.getActionIndices().length == 0 || event.getActionIndices()[0] != 0) {
- return;
- }
- TranslationStatistics translationStatistics =
- TranslationStatistics.loadFromSharedPreference(mSharedPreferences, languageTag);
- if (event.getEventType() == TextClassifierEvent.TYPE_ACTIONS_SHOWN) {
- translationStatistics.increaseShownCountByOne();
- } else if (event.getEventType() == TextClassifierEvent.TYPE_SMART_ACTION) {
- translationStatistics.increaseClickedCountByOne();
- }
- translationStatistics.save(mSharedPreferences, languageTag);
- }
-
- private static final class TranslationStatistics {
- private static final String SEEN_COUNT = "seen_count";
- private static final String CLICK_COUNT = "click_count";
-
- private int mShownCount;
- private int mClickCount;
-
- private TranslationStatistics() {
- this(/* seenCount= */ 0, /* clickCount= */ 0);
- }
-
- private TranslationStatistics(int seenCount, int clickCount) {
- mShownCount = seenCount;
- mClickCount = clickCount;
- }
-
- static TranslationStatistics loadFromSharedPreference(
- SharedPreferences sharedPreferences, String languageTag) {
- String serializedString = sharedPreferences.getString(languageTag, null);
- return TranslationStatistics.fromSerializedString(serializedString);
- }
-
- void save(SharedPreferences sharedPreferences, String languageTag) {
- // TODO: Consider to store it in a database.
- sharedPreferences.edit().putString(languageTag, serializeToString()).apply();
- }
-
- private String serializeToString() {
- try {
- JSONObject jsonObject = new JSONObject();
- jsonObject.put(SEEN_COUNT, mShownCount);
- jsonObject.put(CLICK_COUNT, mClickCount);
- return jsonObject.toString();
- } catch (JSONException ex) {
- TcLog.e(TAG, "serializeToString: ", ex);
- }
- return "";
- }
-
- void increaseShownCountByOne() {
- mShownCount += 1;
- }
-
- void increaseClickedCountByOne() {
- mClickCount += 1;
- }
-
- float getScore() {
- if (mShownCount == 0) {
- return 0f;
- }
- return mClickCount / (float) mShownCount;
- }
-
- int getShownCount() {
- return mShownCount;
- }
-
- static TranslationStatistics fromSerializedString(String str) {
- if (str == null) {
- return new TranslationStatistics();
- }
- try {
- JSONObject jsonObject = new JSONObject(str);
- int seenCount = jsonObject.getInt(SEEN_COUNT);
- int clickCount = jsonObject.getInt(CLICK_COUNT);
- return new TranslationStatistics(seenCount, clickCount);
- } catch (JSONException ex) {
- TcLog.e(TAG, "Failed to parse " + str, ex);
- }
- return new TranslationStatistics();
- }
-
- @Override
- public String toString() {
- return "TranslationStatistics{"
- + "mShownCount="
- + mShownCount
- + ", mClickCount="
- + mClickCount
- + '}';
- }
- }
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/SystemLanguagesProvider.java b/java/src/com/android/textclassifier/ulp/SystemLanguagesProvider.java
index e167566..eefa913 100644
--- a/java/src/com/android/textclassifier/ulp/SystemLanguagesProvider.java
+++ b/java/src/com/android/textclassifier/ulp/SystemLanguagesProvider.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,18 +18,17 @@
import android.content.res.Resources;
import android.os.LocaleList;
-
import java.util.ArrayList;
import java.util.List;
final class SystemLanguagesProvider {
- List<String> getSystemLanguageTags() {
- LocaleList localeList = Resources.getSystem().getConfiguration().getLocales();
- List<String> languageTags = new ArrayList<>();
- for (int i = 0; i < localeList.size(); i++) {
- languageTags.add(localeList.get(i).getLanguage());
- }
- return languageTags;
+ List<String> getSystemLanguageTags() {
+ LocaleList localeList = Resources.getSystem().getConfiguration().getLocales();
+ List<String> languageTags = new ArrayList<>();
+ for (int i = 0; i < localeList.size(); i++) {
+ languageTags.add(localeList.get(i).getLanguage());
}
+ return languageTags;
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/database/LanguageProfileDatabase.java b/java/src/com/android/textclassifier/ulp/database/LanguageProfileDatabase.java
index 2ae317e..2e213d5 100644
--- a/java/src/com/android/textclassifier/ulp/database/LanguageProfileDatabase.java
+++ b/java/src/com/android/textclassifier/ulp/database/LanguageProfileDatabase.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
package com.android.textclassifier.ulp.database;
import android.content.Context;
-
import androidx.annotation.GuardedBy;
import androidx.room.Database;
import androidx.room.Room;
@@ -30,36 +29,36 @@
* existing if there is one) and use it.
*/
@Database(
- entities = {LanguageSignalInfo.class},
- version = 1,
- exportSchema = false)
+ entities = {LanguageSignalInfo.class},
+ version = 1,
+ exportSchema = false)
public abstract class LanguageProfileDatabase extends RoomDatabase {
- private static final Object sLock = new Object();
+ private static final Object lock = new Object();
- @GuardedBy("sLock")
- private static LanguageProfileDatabase sINSTANCE;
+ @GuardedBy("lock")
+ private static LanguageProfileDatabase instance;
- /**
- * Returns {@link LanguageSignalInfoDao} object belonging to the {@link LanguageProfileDatabase}
- * with which we can call database queries.
- */
- public abstract LanguageSignalInfoDao languageInfoDao();
+ /**
+ * Returns {@link LanguageSignalInfoDao} object belonging to the {@link LanguageProfileDatabase}
+ * with which we can call database queries.
+ */
+ public abstract LanguageSignalInfoDao languageInfoDao();
- /**
- * Create an instance of {@link LanguageProfileDatabase} for chosen context or existing one if
- * it was already created.
- */
- public static LanguageProfileDatabase getInstance(final Context context) {
- synchronized (sLock) {
- if (sINSTANCE == null) {
- sINSTANCE =
- Room.databaseBuilder(
- context.getApplicationContext(),
- LanguageProfileDatabase.class,
- "language_profile")
- .build();
- }
- return sINSTANCE;
- }
+ /**
+ * Create an instance of {@link LanguageProfileDatabase} for chosen context or existing one if it
+ * was already created.
+ */
+ public static LanguageProfileDatabase getInstance(final Context context) {
+ synchronized (lock) {
+ if (instance == null) {
+ instance =
+ Room.databaseBuilder(
+ context.getApplicationContext(),
+ LanguageProfileDatabase.class,
+ "language_profile")
+ .build();
+ }
+ return instance;
}
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfo.java b/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfo.java
index b494e93..98f63ac 100644
--- a/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfo.java
+++ b/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfo.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -18,10 +18,9 @@
import androidx.annotation.IntDef;
import androidx.annotation.NonNull;
-import androidx.core.util.Preconditions;
import androidx.room.ColumnInfo;
import androidx.room.Entity;
-
+import com.google.common.base.Preconditions;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
@@ -32,77 +31,78 @@
* specified language.
*/
@Entity(
- tableName = "language_signal_infos",
- primaryKeys = {"languageTag", "source"})
+ tableName = "language_signal_infos",
+ primaryKeys = {"languageTag", "source"})
public final class LanguageSignalInfo {
- @Retention(RetentionPolicy.SOURCE)
- @IntDef({SUGGEST_CONVERSATION_ACTIONS, CLASSIFY_TEXT})
- public @interface Source {}
+ /** The source of the signal */
+ @Retention(RetentionPolicy.SOURCE)
+ @IntDef({SUGGEST_CONVERSATION_ACTIONS, CLASSIFY_TEXT})
+ public @interface Source {}
- public static final int SUGGEST_CONVERSATION_ACTIONS = 0;
- public static final int CLASSIFY_TEXT = 1;
+ public static final int SUGGEST_CONVERSATION_ACTIONS = 0;
+ public static final int CLASSIFY_TEXT = 1;
- @NonNull
- @ColumnInfo(name = "languageTag")
- private String mLanguageTag;
+ @NonNull
+ @ColumnInfo(name = "languageTag")
+ private final String languageTag;
- @ColumnInfo(name = "source")
- private int mSource;
+ @ColumnInfo(name = "source")
+ private final int source;
- @ColumnInfo(name = "count")
- private int mCount;
+ @ColumnInfo(name = "count")
+ private final int count;
- public LanguageSignalInfo(String languageTag, @Source int source, int count) {
- mLanguageTag = Preconditions.checkNotNull(languageTag);
- mSource = source;
- mCount = count;
+ public LanguageSignalInfo(String languageTag, @Source int source, int count) {
+ this.languageTag = Preconditions.checkNotNull(languageTag);
+ this.source = source;
+ this.count = count;
+ }
+
+ public String getLanguageTag() {
+ return languageTag;
+ }
+
+ @Source
+ public int getSource() {
+ return source;
+ }
+
+ public int getCount() {
+ return count;
+ }
+
+ @Override
+ public String toString() {
+ String src = "OTHER";
+ if (source == SUGGEST_CONVERSATION_ACTIONS) {
+ src = "SUGGEST_CONVERSATION_ACTIONS";
+ } else if (source == CLASSIFY_TEXT) {
+ src = "CLASSIFY_TEXT";
}
- public String getLanguageTag() {
- return mLanguageTag;
- }
+ return languageTag + "_" + src + ": " + count;
+ }
- @Source
- public int getSource() {
- return mSource;
- }
+ @Override
+ public int hashCode() {
+ int result = languageTag.hashCode();
+ result = 31 * result + Integer.hashCode(source);
+ result = 31 * result + Integer.hashCode(count);
+ return result;
+ }
- public int getCount() {
- return mCount;
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
}
-
- @Override
- public String toString() {
- String src = "OTHER";
- if (mSource == SUGGEST_CONVERSATION_ACTIONS) {
- src = "SUGGEST_CONVERSATION_ACTIONS";
- } else if (mSource == CLASSIFY_TEXT) {
- src = "CLASSIFY_TEXT";
- }
-
- return mLanguageTag + "_" + src + ": " + mCount;
+ if (obj == null || obj.getClass() != getClass()) {
+ return false;
}
-
- @Override
- public int hashCode() {
- int result = mLanguageTag.hashCode();
- result = 31 * result + Integer.hashCode(mSource);
- result = 31 * result + Integer.hashCode(mCount);
- return result;
- }
-
- @Override
- public boolean equals(Object obj) {
- if (this == obj) {
- return true;
- }
- if (obj == null || obj.getClass() != getClass()) {
- return false;
- }
- LanguageSignalInfo info = (LanguageSignalInfo) obj;
- return mLanguageTag.equals(info.getLanguageTag())
- && mSource == info.getSource()
- && mCount == info.getCount();
- }
+ LanguageSignalInfo info = (LanguageSignalInfo) obj;
+ return languageTag.equals(info.getLanguageTag())
+ && source == info.getSource()
+ && count == info.getCount();
+ }
}
diff --git a/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfoDao.java b/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfoDao.java
index adc1a59..ca792b2 100644
--- a/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfoDao.java
+++ b/java/src/com/android/textclassifier/ulp/database/LanguageSignalInfoDao.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,9 +20,6 @@
import androidx.room.Insert;
import androidx.room.OnConflictStrategy;
import androidx.room.Query;
-
-import com.google.common.annotations.VisibleForTesting;
-
import java.util.List;
/**
@@ -34,31 +31,30 @@
@Dao
public interface LanguageSignalInfoDao {
- /**
- * Inserts {@link LanguageSignalInfo} object into the Room database. If there was already entity
- * with the same language local and source type replaces it with a passed one.
- *
- * @param languageSignalInfo object to insert into the database.
- */
- @Insert(onConflict = OnConflictStrategy.REPLACE)
- @VisibleForTesting
- void insertLanguageInfo(LanguageSignalInfo languageSignalInfo);
+ /**
+ * Inserts {@link LanguageSignalInfo} object into the Room database. If there was already entity
+ * with the same language local and source type replaces it with a passed one.
+ *
+ * @param languageSignalInfo object to insert into the database.
+ */
+ @Insert(onConflict = OnConflictStrategy.REPLACE)
+ void insertLanguageInfo(LanguageSignalInfo languageSignalInfo);
- /** Returns all the {@link LanguageSignalInfo} objects which have source like {@code src}. */
- @Query("SELECT * FROM language_signal_infos WHERE source = :src")
- List<LanguageSignalInfo> getBySource(@LanguageSignalInfo.Source int src);
+ /** Returns all the {@link LanguageSignalInfo} objects which have source like {@code src}. */
+ @Query("SELECT * FROM language_signal_infos WHERE source = :src")
+ List<LanguageSignalInfo> getBySource(@LanguageSignalInfo.Source int src);
- /** Returns all the {@link LanguageSignalInfo} objects stored in the database. */
- @Query("SELECT * FROM language_signal_infos")
- List<LanguageSignalInfo> getAll();
+ /** Returns all the {@link LanguageSignalInfo} objects stored in the database. */
+ @Query("SELECT * FROM language_signal_infos")
+ List<LanguageSignalInfo> getAll();
- /**
- * Increases the count of the specified signal by the specified increment or inserts a new entry
- * if the signal is not in the database yet.
- */
- @Query(
- "INSERT INTO language_signal_infos VALUES(:languageTag, :source, :increment)"
- + " ON CONFLICT(languageTag, source) DO UPDATE SET count = count + :increment")
- void increaseSignalCount(
- String languageTag, @LanguageSignalInfo.Source int source, int increment);
+ /**
+ * Increases the count of the specified signal by the specified increment or inserts a new entry
+ * if the signal is not in the database yet.
+ */
+ @Query(
+ "INSERT INTO language_signal_infos VALUES(:languageTag, :source, :increment)"
+ + " ON CONFLICT(languageTag, source) DO UPDATE SET count = count + :increment")
+ void increaseSignalCount(
+ String languageTag, @LanguageSignalInfo.Source int source, int increment);
}
diff --git a/java/src/com/android/textclassifier/ulp/kmeans/KMeans.java b/java/src/com/android/textclassifier/ulp/kmeans/KMeans.java
index e77b05e..a01af54 100644
--- a/java/src/com/android/textclassifier/ulp/kmeans/KMeans.java
+++ b/java/src/com/android/textclassifier/ulp/kmeans/KMeans.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,11 +17,7 @@
package com.android.textclassifier.ulp.kmeans;
import android.util.Log;
-
-import androidx.annotation.NonNull;
-
import com.google.common.annotations.VisibleForTesting;
-
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -29,226 +25,218 @@
/** Simple K-Means implementation which found in android internal ml library */
public class KMeans {
- private static final boolean DEBUG = false;
- private static final String TAG = "KMeans";
- private final Random mRandomState;
- private final int mMaxIterations;
- private float mSqConvergenceEpsilon;
+ private static final boolean DEBUG = false;
+ private static final String TAG = "KMeans";
+ private final Random randomState;
+ private final int maxIterations;
+ private final float sqConvergenceEpsilon;
- public KMeans() {
- this(new Random());
+ public KMeans() {
+ this(new Random());
+ }
+
+ public KMeans(Random random) {
+ this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
+ }
+
+ public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
+ randomState = random;
+ this.maxIterations = maxIterations;
+ sqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon; // squared: compared against sqDistance(), no sqrt needed
+ }
+
+ /**
+ * Runs k-means on inputData, trying to find k means (cluster centroids).
+ *
+ * <p>K-Means is known for getting stuck in local optima, so you might want to run it multiple
+ * times and keep the result with the highest {@link KMeans#score(List)}.
+ *
+ * @param k The number of means (clusters) to compute.
+ * @param inputData Input points, one per row; rows must be non-null and of equal length.
+ * @return A list of k Means, each holding a centroid and the data points assigned to it.
+ */
+ public List<Mean> predict(final int k, final float[][] inputData) {
+ checkDataSetSanity(inputData);
+ int dimension = inputData[0].length;
+
+ final ArrayList<Mean> means = new ArrayList<>();
+ for (int i = 0; i < k; i++) {
+ Mean m = new Mean(dimension);
+ for (int j = 0; j < dimension; j++) {
+ m.centroid[j] = randomState.nextFloat(); // NOTE(review): presumably uniform in [0, 1); assumes normalized data - confirm
+ }
+ means.add(m);
}
- public KMeans(Random random) {
- this(random, 30 /* maxIterations */, 0.005f /* convergenceEpsilon */);
+ // Iterate until we converge or run out of iterations
+ boolean converged = false;
+ for (int i = 0; i < maxIterations; i++) {
+ converged = step(means, inputData);
+ if (converged) {
+ if (DEBUG) {
+ Log.d(TAG, "Converged at iteration: " + i);
+ }
+ break;
+ }
+ }
+ if (!converged && DEBUG) {
+ Log.d(TAG, "Did not converge");
}
- public KMeans(Random random, int maxIterations, float convergenceEpsilon) {
- mRandomState = random;
- mMaxIterations = maxIterations;
- mSqConvergenceEpsilon = convergenceEpsilon * convergenceEpsilon;
+ return means;
+ }
+
+ /**
+ * Sums the Euclidean distances between the centroids of every ordered pair of distinct means, so
+ * each unordered pair is counted twice. Higher scores mean more widely separated centroids.
+ *
+ * @param means Means whose centroids are compared.
+ * @return The sum of pairwise centroid distances (0 for fewer than two means).
+ */
+ public static double score(List<Mean> means) {
+ double score = 0;
+ final int meansSize = means.size();
+ for (int i = 0; i < meansSize; i++) {
+ Mean mean = means.get(i);
+ for (int j = 0; j < meansSize; j++) {
+ Mean compareTo = means.get(j);
+ if (mean == compareTo) { // skip self (identity comparison, not equals)
+ continue;
+ }
+ double distance = Math.sqrt(sqDistance(mean.centroid, compareTo.centroid));
+ score += distance;
+ }
+ }
+ return score;
+ }
+
+ /** Throws IllegalArgumentException unless inputData is non-null, non-empty and has non-null, equal-length rows. */
+ @VisibleForTesting
+ public void checkDataSetSanity(float[][] inputData) {
+ if (inputData == null) {
+ throw new IllegalArgumentException("Data set is null.");
+ } else if (inputData.length == 0) {
+ throw new IllegalArgumentException("Data set is empty.");
+ } else if (inputData[0] == null) {
+ throw new IllegalArgumentException("Bad data set format.");
}
- /**
- * Runs k-means on the input data (X) trying to find k means.
- *
- * <p>K-Means is known for getting stuck into local optima, so you might want to run it multiple
- * time and argmax on {@link KMeans#score(List)}
- *
- * @param k The number of points to return.
- * @param inputData Input data.
- * @return An array of k Means, each representing a centroid and data points that belong to it.
- */
- public List<Mean> predict(final int k, final float[][] inputData) {
- checkDataSetSanity(inputData);
- int dimension = inputData[0].length;
+ final int dimension = inputData[0].length;
+ final int length = inputData.length;
+ for (int i = 1; i < length; i++) {
+ if (inputData[i] == null || inputData[i].length != dimension) {
+ throw new IllegalArgumentException("Bad data set format.");
+ }
+ }
+ }
- final ArrayList<Mean> means = new ArrayList<>();
- for (int i = 0; i < k; i++) {
- Mean m = new Mean(dimension);
- for (int j = 0; j < dimension; j++) {
- m.mCentroid[j] = mRandomState.nextFloat();
- }
- means.add(m);
- }
+ /**
+ * Runs one k-means iteration: assigns every point to its nearest mean, then recenters the means.
+ *
+ * @param means Current means; their centroids and assigned items are updated in place.
+ * @param inputData Input points, one per row.
+ * @return True if no non-empty mean's centroid moved more than the convergence epsilon.
+ */
+ private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
- // Iterate until we converge or run out of iterations
- boolean converged = false;
- for (int i = 0; i < mMaxIterations; i++) {
- converged = step(means, inputData);
- if (converged) {
- if (DEBUG) Log.d(TAG, "Converged at iteration: " + i);
- break;
- }
- }
- if (!converged && DEBUG) Log.d(TAG, "Did not converge");
-
- return means;
+ // Clean up the previous state because we need to compute
+ // which point belongs to each mean again.
+ for (int i = means.size() - 1; i >= 0; i--) {
+ final Mean mean = means.get(i);
+ mean.closestItems.clear();
+ }
+ for (int i = inputData.length - 1; i >= 0; i--) {
+ final float[] current = inputData[i];
+ final Mean nearest = nearestMean(current, means);
+ nearest.closestItems.add(current); // NOTE(review): NPE if means is empty (k == 0): nearestMean returns null
}
- /**
- * Score calculates the inertia between means. This can be considered as an E step of an EM
- * algorithm.
- *
- * @param means Means to use when calculating score.
- * @return The score
- */
- public static double score(@NonNull List<Mean> means) {
- double score = 0;
- final int meansSize = means.size();
- for (int i = 0; i < meansSize; i++) {
- Mean mean = means.get(i);
- for (int j = 0; j < meansSize; j++) {
- Mean compareTo = means.get(j);
- if (mean == compareTo) {
- continue;
- }
- double distance = Math.sqrt(sqDistance(mean.mCentroid, compareTo.mCentroid));
- score += distance;
- }
+ boolean converged = true;
+ // Move each mean to the average of the points assigned to it.
+ for (int i = means.size() - 1; i >= 0; i--) {
+ final Mean mean = means.get(i);
+ if (mean.closestItems.isEmpty()) {
+ continue; // no assigned points: keep the previous centroid
+ }
+
+ // Compute the new mean centroid:
+ // 1. Sum all assigned points
+ // 2. Average them
+ final float[] oldCentroid = mean.centroid;
+ mean.centroid = new float[oldCentroid.length];
+ for (int j = 0; j < mean.closestItems.size(); j++) {
+ // Update each centroid component
+ for (int p = 0; p < mean.centroid.length; p++) {
+ mean.centroid[p] += mean.closestItems.get(j)[p];
}
- return score;
+ }
+ for (int j = 0; j < mean.centroid.length; j++) {
+ mean.centroid[j] /= mean.closestItems.size();
+ }
+
+ // Not converged if any centroid moved farther than the convergence epsilon.
+ if (sqDistance(oldCentroid, mean.centroid) > sqConvergenceEpsilon) {
+ converged = false;
+ }
+ }
+ return converged;
+ }
+
+ /** Returns the first mean whose centroid has the smallest distance to point; null if means is empty. */
+ @VisibleForTesting
+ public static Mean nearestMean(float[] point, List<Mean> means) {
+ Mean nearest = null;
+ float nearestDistance = Float.MAX_VALUE;
+
+ final int meanCount = means.size();
+ for (int i = 0; i < meanCount; i++) {
+ Mean next = means.get(i);
+ // Comparing squared distances is enough: sqrt is monotonic, so the ordering is unchanged.
+ // A distance that is NaN or >= Float.MAX_VALUE never wins, so nearest may stay null.
+ float nextDistance = sqDistance(point, next.centroid);
+ if (nextDistance < nearestDistance) {
+ nearest = next;
+ nearestDistance = nextDistance;
+ }
+ }
+ return nearest;
+ }
+
+ /** Returns the squared Euclidean distance between a and b, over the first a.length components. */
+ @VisibleForTesting
+ public static float sqDistance(float[] a, float[] b) {
+ float dist = 0;
+ final int length = a.length;
+ for (int i = 0; i < length; i++) {
+ dist += (a[i] - b[i]) * (a[i] - b[i]);
+ }
+ return dist;
+ }
+
+ /** A cluster: its centroid plus the data points currently assigned to it. */
+ public static class Mean {
+ float[] centroid;
+ final ArrayList<float[]> closestItems = new ArrayList<>();
+
+ public Mean(int dimension) {
+ centroid = new float[dimension];
}
- /** @param inputData */
- @VisibleForTesting
- public void checkDataSetSanity(float[][] inputData) {
- if (inputData == null) {
- throw new IllegalArgumentException("Data set is null.");
- } else if (inputData.length == 0) {
- throw new IllegalArgumentException("Data set is empty.");
- } else if (inputData[0] == null) {
- throw new IllegalArgumentException("Bad data set format.");
- }
-
- final int dimension = inputData[0].length;
- final int length = inputData.length;
- for (int i = 1; i < length; i++) {
- if (inputData[i] == null || inputData[i].length != dimension) {
- throw new IllegalArgumentException("Bad data set format.");
- }
- }
+ public Mean(float... centroid) {
+ this.centroid = centroid; // keeps the caller's array (no copy)
}
- /**
- * K-Means iteration.
- *
- * @param means Current means
- * @param inputData Input data
- * @return True if data set converged
- */
- private boolean step(final ArrayList<Mean> means, final float[][] inputData) {
-
- // Clean up the previous state because we need to compute
- // which point belongs to each mean again.
- for (int i = means.size() - 1; i >= 0; i--) {
- final Mean mean = means.get(i);
- mean.mClosestItems.clear();
- }
- for (int i = inputData.length - 1; i >= 0; i--) {
- final float[] current = inputData[i];
- final Mean nearest = nearestMean(current, means);
- nearest.mClosestItems.add(current);
- }
-
- boolean converged = true;
- // Move each mean towards the nearest data set points
- for (int i = means.size() - 1; i >= 0; i--) {
- final Mean mean = means.get(i);
- if (mean.mClosestItems.size() == 0) {
- continue;
- }
-
- // Compute the new mean centroid:
- // 1. Sum all all points
- // 2. Average them
- final float[] oldCentroid = mean.mCentroid;
- mean.mCentroid = new float[oldCentroid.length];
- for (int j = 0; j < mean.mClosestItems.size(); j++) {
- // Update each centroid component
- for (int p = 0; p < mean.mCentroid.length; p++) {
- mean.mCentroid[p] += mean.mClosestItems.get(j)[p];
- }
- }
- for (int j = 0; j < mean.mCentroid.length; j++) {
- mean.mCentroid[j] /= mean.mClosestItems.size();
- }
-
- // We converged if the centroid didn't move for any of the means.
- if (sqDistance(oldCentroid, mean.mCentroid) > mSqConvergenceEpsilon) {
- converged = false;
- }
- }
- return converged;
+ public float[] getCentroid() {
+ return centroid; // the internal array, not a copy
}
- /**
- * @param point
- * @param means
- * @return
- */
- @VisibleForTesting
- public static Mean nearestMean(float[] point, List<Mean> means) {
- Mean nearest = null;
- float nearestDistance = Float.MAX_VALUE;
-
- final int meanCount = means.size();
- for (int i = 0; i < meanCount; i++) {
- Mean next = means.get(i);
- // We don't need the sqrt when comparing distances in euclidean space
- // because they exist on both sides of the equation and cancel each other out.
- float nextDistance = sqDistance(point, next.mCentroid);
- if (nextDistance < nearestDistance) {
- nearest = next;
- nearestDistance = nextDistance;
- }
- }
- return nearest;
+ public List<float[]> getItems() {
+ return closestItems; // the live internal list, not a copy
}
- /**
- * @param a
- * @param b
- * @return
- */
- @VisibleForTesting
- public static float sqDistance(float[] a, float[] b) {
- float dist = 0;
- final int length = a.length;
- for (int i = 0; i < length; i++) {
- dist += (a[i] - b[i]) * (a[i] - b[i]);
- }
- return dist;
+ @Override
+ public String toString() {
+ return "Mean(centroid: " + Arrays.toString(centroid) + ", size: " + closestItems.size() + ")";
}
-
- /** Definition of a mean, contains a centroid and points on its cluster. */
- public static class Mean {
- float[] mCentroid;
- final ArrayList<float[]> mClosestItems = new ArrayList<>();
-
- public Mean(int dimension) {
- mCentroid = new float[dimension];
- }
-
- public Mean(float... centroid) {
- mCentroid = centroid;
- }
-
- public float[] getCentroid() {
- return mCentroid;
- }
-
- public List<float[]> getItems() {
- return mClosestItems;
- }
-
- @Override
- public String toString() {
- return "Mean(centroid: "
- + Arrays.toString(mCentroid)
- + ", size: "
- + mClosestItems.size()
- + ")";
- }
- }
+ }
}
diff --git a/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java b/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java
index addc8e8..ea96633 100644
--- a/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java
+++ b/java/src/com/android/textclassifier/utils/IndentingPrintWriter.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2019 The Android Open Source Project
+ * Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -11,13 +11,12 @@
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
- * limitations under the License
+ * limitations under the License.
*/
package com.android.textclassifier.utils;
-import androidx.core.util.Preconditions;
-
+import com.google.common.base.Preconditions;
import java.io.PrintWriter;
/**
@@ -26,52 +25,52 @@
* @see PrintWriter
*/
public final class IndentingPrintWriter {
- private static final String SINGLE_INDENT = " ";
+ private static final String SINGLE_INDENT = " ";
- private final PrintWriter mWriter;
- private StringBuilder mIndentBuilder = new StringBuilder();
- private String mCurrentIndent = "";
+ private final PrintWriter writer;
+ private final StringBuilder indentBuilder = new StringBuilder();
+ private String currentIndent = ""; // cached indentBuilder.toString(), refreshed on every indent change
- public IndentingPrintWriter(PrintWriter writer) {
- mWriter = Preconditions.checkNotNull(writer);
- }
+ public IndentingPrintWriter(PrintWriter writer) {
+ this.writer = Preconditions.checkNotNull(writer);
+ }
- /** Prints a string. */
- public IndentingPrintWriter println(String string) {
- mWriter.print(mCurrentIndent);
- mWriter.print(string);
- mWriter.println();
- return this;
- }
+ /** Prints the current indent, then the string, then a line separator. */
+ public IndentingPrintWriter println(String string) {
+ writer.print(currentIndent);
+ writer.print(string);
+ writer.println();
+ return this;
+ }
- /** Prints a empty line */
- public IndentingPrintWriter println() {
- mWriter.println();
- return this;
- }
+ /** Prints an empty line (no indent is written). */
+ public IndentingPrintWriter println() {
+ writer.println();
+ return this;
+ }
- /** Increases indents for subsequent texts. */
- public IndentingPrintWriter increaseIndent() {
- mIndentBuilder.append(SINGLE_INDENT);
- mCurrentIndent = mIndentBuilder.toString();
- return this;
- }
+ /** Adds one indent level to lines subsequently printed by println(String). */
+ public IndentingPrintWriter increaseIndent() {
+ indentBuilder.append(SINGLE_INDENT);
+ currentIndent = indentBuilder.toString();
+ return this;
+ }
- /** Decreases indents for subsequent texts. */
- public IndentingPrintWriter decreaseIndent() {
- mIndentBuilder.delete(0, SINGLE_INDENT.length());
- mCurrentIndent = mIndentBuilder.toString();
- return this;
- }
+ /** Removes one indent level; has no effect when there is no indent. */
+ public IndentingPrintWriter decreaseIndent() {
+ indentBuilder.delete(0, SINGLE_INDENT.length());
+ currentIndent = indentBuilder.toString();
+ return this;
+ }
- /** Prints a key-valued pair. */
- public IndentingPrintWriter printPair(String key, Object value) {
- println(String.format("%s=%s", key, String.valueOf(value)));
- return this;
- }
+ /** Prints an indented "key=value" line; a null value is printed as "null". */
+ public IndentingPrintWriter printPair(String key, Object value) {
+ println(String.format("%s=%s", key, String.valueOf(value)));
+ return this;
+ }
- /** Flushes the stream. */
- public void flush() {
- mWriter.flush();
- }
+ /** Flushes the underlying PrintWriter. */
+ public void flush() {
+ writer.flush();
+ }
}