Use LangID to detect the language of the text and pass it to the annotator

1. Use LangID to detect the language tags of the text and pass the
   result to the native annotator.
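2. Add the feature flags detect_languages_from_text_enabled and
   translate_in_classification_enabled to turn off
   detectLanguageTagsFromText and the "translate" action in
   classification, respectively.
3. Pass the reference time zone of conversation messages and the full
   resource locale list (not just the first locale) to the native side.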

Test: Select a non-English word; the "define" action is no longer shown.
Test: atest frameworks/base/core/tests/coretests/src/android/view/textclassifier/

Bug: 123705564
Bug: 128413589

Change-Id: I6a68caca1e0709e63319907acbf8d776894f555b
diff --git a/core/java/android/view/textclassifier/ActionsSuggestionsHelper.java b/core/java/android/view/textclassifier/ActionsSuggestionsHelper.java
index ddbff7b..17edf5c 100644
--- a/core/java/android/view/textclassifier/ActionsSuggestionsHelper.java
+++ b/core/java/android/view/textclassifier/ActionsSuggestionsHelper.java
@@ -82,9 +82,12 @@
             long referenceTime = message.getReferenceTime() == null
                     ? 0
                     : message.getReferenceTime().toInstant().toEpochMilli();
+            String timeZone = message.getReferenceTime() == null
+                    ? null
+                    : message.getReferenceTime().getZone().getId();
             nativeMessages.push(new ActionsSuggestionsModel.ConversationMessage(
                     personEncoder.encode(message.getAuthor()),
-                    message.getText().toString(), referenceTime,
+                    message.getText().toString(), referenceTime, timeZone,
                     languageDetector.apply(message.getText())));
         }
         return nativeMessages.toArray(
diff --git a/core/java/android/view/textclassifier/TextClassificationConstants.java b/core/java/android/view/textclassifier/TextClassificationConstants.java
index 125b0d3..38e72e3 100644
--- a/core/java/android/view/textclassifier/TextClassificationConstants.java
+++ b/core/java/android/view/textclassifier/TextClassificationConstants.java
@@ -48,6 +48,8 @@
  * notification_conversation_action_types_default   (String[])
  * lang_id_threshold_override                       (float)
  * template_intent_factory_enabled                  (boolean)
+ * translate_in_classification_enabled              (boolean)
+ * detect_languages_from_text_enabled               (boolean)
  * </pre>
  *
  * <p>
@@ -139,7 +141,7 @@
     private static final String NOTIFICATION_CONVERSATION_ACTION_TYPES_DEFAULT =
             "notification_conversation_action_types_default";
     /**
-     * Threshold in classifyText to consider a text is in a foreign language.
+     * Threshold to accept a suggested language from the LangID model.
      */
     private static final String LANG_ID_THRESHOLD_OVERRIDE = "lang_id_threshold_override";
     /**
@@ -147,6 +149,18 @@
      */
     private static final String TEMPLATE_INTENT_FACTORY_ENABLED = "template_intent_factory_enabled";
 
+    /**
+     * Whether to enable the "translate" action in classifyText.
+     */
+    private static final String TRANSLATE_IN_CLASSIFICATION_ENABLED =
+            "translate_in_classification_enabled";
+    /**
+     * Whether to detect the languages of the request text with LangID and pass them to the
+     * native model.
+     */
+    private static final String DETECT_LANGUAGES_FROM_TEXT_ENABLED =
+            "detect_languages_from_text_enabled";
+
     private static final boolean LOCAL_TEXT_CLASSIFIER_ENABLED_DEFAULT = true;
     private static final boolean SYSTEM_TEXT_CLASSIFIER_ENABLED_DEFAULT = true;
     private static final boolean MODEL_DARK_LAUNCH_ENABLED_DEFAULT = false;
@@ -183,11 +197,13 @@
     /**
      * < 0  : Not set. Use value from LangId model.
      * 0 - 1: Override value in LangId model.
-     * > 1  : Effectively turns off the foreign language detection. Scores should never be > 1.
+     * > 1  : Invalid. Text language detection and the translate action are skipped.
      * @see EntityConfidence
      */
     private static final float LANG_ID_THRESHOLD_OVERRIDE_DEFAULT = -1f;
     private static final boolean TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT = true;
+    private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
+    private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
 
     private final boolean mSystemTextClassifierEnabled;
     private final boolean mLocalTextClassifierEnabled;
@@ -207,6 +223,8 @@
     private final List<String> mNotificationConversationActionTypesDefault;
     private final float mLangIdThresholdOverride;
     private final boolean mTemplateIntentFactoryEnabled;
+    private final boolean mTranslateInClassificationEnabled;
+    private final boolean mDetectLanguagesFromTextEnabled;
 
     private TextClassificationConstants(@Nullable String settings) {
         ConfigParser configParser = new ConfigParser(settings);
@@ -280,6 +298,10 @@
         mTemplateIntentFactoryEnabled = configParser.getBoolean(
                 TEMPLATE_INTENT_FACTORY_ENABLED,
                 TEMPLATE_INTENT_FACTORY_ENABLED_DEFAULT);
+        mTranslateInClassificationEnabled = configParser.getBoolean(
+                TRANSLATE_IN_CLASSIFICATION_ENABLED, TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT);
+        mDetectLanguagesFromTextEnabled = configParser.getBoolean(
+                DETECT_LANGUAGES_FROM_TEXT_ENABLED, DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT);
     }
 
     /** Load from a settings string. */
@@ -359,6 +381,14 @@
         return mTemplateIntentFactoryEnabled;
     }
 
+    public boolean isTranslateInClassificationEnabled() {
+        return mTranslateInClassificationEnabled;
+    }
+
+    public boolean isDetectLanguagesFromTextEnabled() {
+        return mDetectLanguagesFromTextEnabled;
+    }
+
     private static List<String> parseStringList(String listStr) {
         return Collections.unmodifiableList(Arrays.asList(listStr.split(STRING_LIST_DELIMITER)));
     }
@@ -385,6 +415,8 @@
                 mNotificationConversationActionTypesDefault);
         pw.printPair("getLangIdThresholdOverride", mLangIdThresholdOverride);
         pw.printPair("isTemplateIntentFactoryEnabled", mTemplateIntentFactoryEnabled);
+        pw.printPair("isTranslateInClassificationEnabled", mTranslateInClassificationEnabled);
+        pw.printPair("isDetectLanguageFromTextEnabled", mDetectLanguagesFromTextEnabled);
         pw.decreaseIndent();
         pw.println();
     }
diff --git a/core/java/android/view/textclassifier/TextClassifierImpl.java b/core/java/android/view/textclassifier/TextClassifierImpl.java
index 35cd678..0f3a8cf 100644
--- a/core/java/android/view/textclassifier/TextClassifierImpl.java
+++ b/core/java/android/view/textclassifier/TextClassifierImpl.java
@@ -164,6 +164,7 @@
             if (string.length() > 0
                     && rangeLength <= mSettings.getSuggestSelectionMaxRangeLength()) {
                 final String localesString = concatenateLocales(request.getDefaultLocales());
+                final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
                 final ZonedDateTime refTime = ZonedDateTime.now();
                 final AnnotatorModel annotatorImpl =
                         getAnnotatorImpl(request.getDefaultLocales());
@@ -175,7 +176,7 @@
                 } else {
                     final int[] startEnd = annotatorImpl.suggestSelection(
                             string, request.getStartIndex(), request.getEndIndex(),
-                            new AnnotatorModel.SelectionOptions(localesString));
+                            new AnnotatorModel.SelectionOptions(localesString, detectLanguageTags));
                     start = startEnd[0];
                     end = startEnd[1];
                 }
@@ -189,7 +190,8 @@
                                     new AnnotatorModel.ClassificationOptions(
                                             refTime.toInstant().toEpochMilli(),
                                             refTime.getZone().getId(),
-                                            localesString),
+                                            localesString,
+                                            detectLanguageTags),
                                     // Passing null here to suppress intent generation
                                     // TODO: Use an explicit flag to suppress it.
                                     /* appContext */ null,
@@ -227,6 +229,7 @@
             final String string = request.getText().toString();
             if (string.length() > 0 && rangeLength <= mSettings.getClassifyTextMaxRangeLength()) {
                 final String localesString = concatenateLocales(request.getDefaultLocales());
+                final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
                 final ZonedDateTime refTime = request.getReferenceTime() != null
                         ? request.getReferenceTime() : ZonedDateTime.now();
                 final AnnotatorModel.ClassificationResult[] results =
@@ -236,9 +239,10 @@
                                         new AnnotatorModel.ClassificationOptions(
                                                 refTime.toInstant().toEpochMilli(),
                                                 refTime.getZone().getId(),
-                                                localesString),
+                                                localesString,
+                                                detectLanguageTags),
                                         mContext,
-                                        getResourceLocaleString()
+                                        getResourceLocalesString()
                                 );
                 if (results.length > 0) {
                     return createClassificationResult(
@@ -276,6 +280,8 @@
                     ? request.getEntityConfig().resolveEntityListModifications(
                     getEntitiesForHints(request.getEntityConfig().getHints()))
                     : mSettings.getEntityListDefault();
+            final String localesString = concatenateLocales(request.getDefaultLocales());
+            final String detectLanguageTags = detectLanguageTagsFromText(request.getText());
             final AnnotatorModel annotatorImpl =
                     getAnnotatorImpl(request.getDefaultLocales());
             final AnnotatorModel.AnnotatedSpan[] annotations =
@@ -284,7 +290,8 @@
                             new AnnotatorModel.AnnotationOptions(
                                     refTime.toInstant().toEpochMilli(),
                                     refTime.getZone().getId(),
-                                    concatenateLocales(request.getDefaultLocales())));
+                                    localesString,
+                                    detectLanguageTags));
             for (AnnotatorModel.AnnotatedSpan span : annotations) {
                 final AnnotatorModel.ClassificationResult[] results =
                         span.getClassification();
@@ -386,8 +393,8 @@
                 return mFallback.suggestConversationActions(request);
             }
             ActionsSuggestionsModel.ConversationMessage[] nativeMessages =
-                    ActionsSuggestionsHelper.toNativeMessages(request.getConversation(),
-                            this::detectLanguageTagsFromText);
+                    ActionsSuggestionsHelper.toNativeMessages(
+                            request.getConversation(), this::detectLanguageTagsFromText);
             if (nativeMessages.length == 0) {
                 return mFallback.suggestConversationActions(request);
             }
@@ -399,7 +406,7 @@
                             nativeConversation,
                             null,
                             mContext,
-                            getResourceLocaleString());
+                            getResourceLocalesString());
             return createConversationActionResult(request, nativeSuggestions);
         } catch (Throwable t) {
             // Avoid throwing from this method. Log the error.
@@ -463,19 +470,28 @@
 
     @Nullable
     private String detectLanguageTagsFromText(CharSequence text) {
+        if (!mSettings.isDetectLanguagesFromTextEnabled()) {
+            return null;
+        }
+        final float threshold = getLangIdThreshold();
+        if (threshold < 0 || threshold > 1) {
+            Log.w(LOG_TAG,
+                    "[detectLanguageTagsFromText] unexpected threshold is found: " + threshold);
+            return null;
+        }
         TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
         TextLanguage textLanguage = detectLanguage(request);
         int localeHypothesisCount = textLanguage.getLocaleHypothesisCount();
         List<String> languageTags = new ArrayList<>();
         for (int i = 0; i < localeHypothesisCount; i++) {
             ULocale locale = textLanguage.getLocale(i);
-            if (textLanguage.getConfidenceScore(locale) < getForeignLanguageThreshold()) {
+            if (textLanguage.getConfidenceScore(locale) < threshold) {
                 break;
             }
             languageTags.add(locale.toLanguageTag());
         }
         if (languageTags.isEmpty()) {
-            return LocaleList.getDefault().toLanguageTags();
+            return null;
         }
         return String.join(",", languageTags);
     }
@@ -644,10 +660,14 @@
     // TODO: Consider making this public API.
     @Nullable
     private Bundle detectForeignLanguage(String text) {
+        if (!mSettings.isTranslateInClassificationEnabled()) {
+            return null;
+        }
         try {
-            final float threshold = getForeignLanguageThreshold();
-            if (threshold > 1) {
-                Log.v(LOG_TAG, "Foreign language detection disabled.");
+            final float threshold = getLangIdThreshold();
+            if (threshold < 0 || threshold > 1) {
+                Log.w(LOG_TAG,
+                        "[detectForeignLanguage] unexpected threshold is found: " + threshold);
                 return null;
             }
 
@@ -686,11 +706,11 @@
         return null;
     }
 
-    private float getForeignLanguageThreshold() {
+    private float getLangIdThreshold() {
         try {
             return mSettings.getLangIdThresholdOverride() >= 0
                     ? mSettings.getLangIdThresholdOverride()
-                    : getLangIdImpl().getTranslateThreshold();
+                    : getLangIdImpl().getLangIdThreshold();
         } catch (FileNotFoundException e) {
             final float defaultThreshold = 0.5f;
             Log.v(LOG_TAG, "Using default foreign language threshold: " + defaultThreshold);
@@ -746,15 +766,14 @@
     }
 
     /**
-     * Returns the locale string for the current resources configuration.
+     * Returns the comma-separated language tags of the current resource configuration's locales.
      */
-    private String getResourceLocaleString() {
-        // TODO: Pass the locale list once it is supported in native side.
+    private String getResourceLocalesString() {
         try {
-            return mContext.getResources().getConfiguration().getLocales().get(0).toLanguageTag();
+            return mContext.getResources().getConfiguration().getLocales().toLanguageTags();
         } catch (NullPointerException e) {
             // NPE is unexpected. Erring on the side of caution.
-            return LocaleList.getDefault().get(0).toLanguageTag();
+            return LocaleList.getDefault().toLanguageTags();
         }
     }
 }