Merge "Storage refactor for EntityConfidence"
diff --git a/core/java/android/view/textclassifier/EntityConfidence.java b/core/java/android/view/textclassifier/EntityConfidence.java
index 0589d204..19660d9 100644
--- a/core/java/android/view/textclassifier/EntityConfidence.java
+++ b/core/java/android/view/textclassifier/EntityConfidence.java
@@ -18,13 +18,12 @@
 
 import android.annotation.FloatRange;
 import android.annotation.NonNull;
+import android.util.ArrayMap;
 
 import com.android.internal.util.Preconditions;
 
 import java.util.ArrayList;
 import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
@@ -36,42 +35,43 @@
  */
 final class EntityConfidence<T> {
 
-    private final Map<T, Float> mEntityConfidence = new HashMap<>();
-
-    private final Comparator<T> mEntityComparator = (e1, e2) -> {
-        float score1 = mEntityConfidence.get(e1);
-        float score2 = mEntityConfidence.get(e2);
-        if (score1 > score2) {
-            return -1;
-        }
-        if (score1 < score2) {
-            return 1;
-        }
-        return 0;
-    };
+    private final ArrayMap<T, Float> mEntityConfidence = new ArrayMap<>();
+    private final ArrayList<T> mSortedEntities = new ArrayList<>();
 
     EntityConfidence() {}
 
     EntityConfidence(@NonNull EntityConfidence<T> source) {
         Preconditions.checkNotNull(source);
         mEntityConfidence.putAll(source.mEntityConfidence);
+        mSortedEntities.addAll(source.mSortedEntities);
     }
 
     /**
-     * Sets an entity type for the classified text and assigns a confidence score.
+     * Constructs an EntityConfidence from a map of entity to confidence.
      *
-     * @param confidenceScore a value from 0 (low confidence) to 1 (high confidence).
-     *      0 implies the entity does not exist for the classified text.
-     *      Values greater than 1 are clamped to 1.
+     * Non-positive and NaN confidences are dropped; values greater than 1 are clamped to 1.
+     *
+     * @param source a map from entity to a confidence value in the range 0 (low confidence) to
+     *               1 (high confidence).
      */
-    public void setEntityType(
-            @NonNull T type, @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
-        Preconditions.checkNotNull(type);
-        if (confidenceScore > 0) {
-            mEntityConfidence.put(type, Math.min(1, confidenceScore));
-        } else {
-            mEntityConfidence.remove(type);
+    EntityConfidence(@NonNull Map<T, Float> source) {
+        Preconditions.checkNotNull(source);
+
+        // Drop entities without a positive confidence (including NaN) and clamp the rest to 1.
+        mEntityConfidence.ensureCapacity(source.size());
+        for (Map.Entry<T, Float> it : source.entrySet()) {
+            if (!(it.getValue() > 0)) continue;
+            mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
         }
+
+        // Create a list of entities sorted by decreasing confidence for getEntities().
+        mSortedEntities.ensureCapacity(mEntityConfidence.size());
+        mSortedEntities.addAll(mEntityConfidence.keySet());
+        mSortedEntities.sort((e1, e2) -> {
+            float score1 = mEntityConfidence.get(e1);
+            float score2 = mEntityConfidence.get(e2);
+            return Float.compare(score2, score1);
+        });
     }
 
     /**
@@ -80,10 +80,7 @@
      */
     @NonNull
     public List<T> getEntities() {
-        List<T> entities = new ArrayList<>(mEntityConfidence.size());
-        entities.addAll(mEntityConfidence.keySet());
-        entities.sort(mEntityComparator);
-        return Collections.unmodifiableList(entities);
+        return Collections.unmodifiableList(mSortedEntities);
     }
 
     /**
diff --git a/core/java/android/view/textclassifier/TextClassification.java b/core/java/android/view/textclassifier/TextClassification.java
index f675c35..8916323 100644
--- a/core/java/android/view/textclassifier/TextClassification.java
+++ b/core/java/android/view/textclassifier/TextClassification.java
@@ -24,6 +24,7 @@
 import android.content.Intent;
 import android.graphics.drawable.Drawable;
 import android.os.LocaleList;
+import android.util.ArrayMap;
 import android.view.View.OnClickListener;
 import android.view.textclassifier.TextClassifier.EntityType;
 
@@ -32,6 +33,7 @@
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Locale;
+import java.util.Map;
 
 /**
  * Information for generating a widget to handle classified text.
@@ -95,7 +97,6 @@
     @NonNull private final List<Intent> mIntents;
     @NonNull private final List<OnClickListener> mOnClickListeners;
     @NonNull private final EntityConfidence<String> mEntityConfidence;
-    @NonNull private final List<String> mEntities;
     private int mLogType;
     @NonNull private final String mVersionInfo;
 
@@ -105,7 +106,7 @@
             @NonNull List<String> labels,
             @NonNull List<Intent> intents,
             @NonNull List<OnClickListener> onClickListeners,
-            @NonNull EntityConfidence<String> entityConfidence,
+            @NonNull Map<String, Float> entityConfidence,
             int logType,
             @NonNull String versionInfo) {
         Preconditions.checkArgument(labels.size() == intents.size());
@@ -117,7 +118,6 @@
         mIntents = intents;
         mOnClickListeners = onClickListeners;
         mEntityConfidence = new EntityConfidence<>(entityConfidence);
-        mEntities = mEntityConfidence.getEntities();
         mLogType = logType;
         mVersionInfo = versionInfo;
     }
@@ -135,7 +135,7 @@
      */
     @IntRange(from = 0)
     public int getEntityCount() {
-        return mEntities.size();
+        return mEntityConfidence.getEntities().size();
     }
 
     /**
@@ -147,7 +147,7 @@
      */
     @NonNull
     public @EntityType String getEntity(int index) {
-        return mEntities.get(index);
+        return mEntityConfidence.getEntities().get(index);
     }
 
     /**
@@ -311,8 +311,7 @@
         @NonNull private final List<String> mLabels = new ArrayList<>();
         @NonNull private final List<Intent> mIntents = new ArrayList<>();
         @NonNull private final List<OnClickListener> mOnClickListeners = new ArrayList<>();
-        @NonNull private final EntityConfidence<String> mEntityConfidence =
-                new EntityConfidence<>();
+        @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
         private int mLogType;
         @NonNull private String mVersionInfo = "";
 
@@ -334,7 +333,7 @@
         public Builder setEntityType(
                 @NonNull @EntityType String type,
                 @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
-            mEntityConfidence.setEntityType(type, confidenceScore);
+            mEntityConfidence.put(Preconditions.checkNotNull(type), confidenceScore);
             return this;
         }
 
diff --git a/core/java/android/view/textclassifier/TextLinks.java b/core/java/android/view/textclassifier/TextLinks.java
index 76748d2..0e039e3 100644
--- a/core/java/android/view/textclassifier/TextLinks.java
+++ b/core/java/android/view/textclassifier/TextLinks.java
@@ -103,11 +103,7 @@
             mOriginalText = originalText;
             mStart = start;
             mEnd = end;
-            mEntityScores = new EntityConfidence<>();
-
-            for (Map.Entry<String, Float> entry : entityScores.entrySet()) {
-                mEntityScores.setEntityType(entry.getKey(), entry.getValue());
-            }
+            mEntityScores = new EntityConfidence<>(entityScores);  // Drops scores <= 0; caps at 1.
         }
 
         /**
diff --git a/core/java/android/view/textclassifier/TextSelection.java b/core/java/android/view/textclassifier/TextSelection.java
index 480b27a..ced4018 100644
--- a/core/java/android/view/textclassifier/TextSelection.java
+++ b/core/java/android/view/textclassifier/TextSelection.java
@@ -21,12 +21,13 @@
 import android.annotation.NonNull;
 import android.annotation.Nullable;
 import android.os.LocaleList;
+import android.util.ArrayMap;
 import android.view.textclassifier.TextClassifier.EntityType;
 
 import com.android.internal.util.Preconditions;
 
-import java.util.List;
 import java.util.Locale;
+import java.util.Map;
 
 /**
  * Information about where text selection should be.
@@ -36,7 +37,6 @@
     private final int mStartIndex;
     private final int mEndIndex;
     @NonNull private final EntityConfidence<String> mEntityConfidence;
-    @NonNull private final List<String> mEntities;
     @NonNull private final String mLogSource;
     @NonNull private final String mVersionInfo;
 
@@ -46,7 +46,6 @@
         mStartIndex = startIndex;
         mEndIndex = endIndex;
         mEntityConfidence = new EntityConfidence<>(entityConfidence);
-        mEntities = mEntityConfidence.getEntities();
         mLogSource = logSource;
         mVersionInfo = versionInfo;
     }
@@ -70,7 +69,7 @@
      */
     @IntRange(from = 0)
     public int getEntityCount() {
-        return mEntities.size();
+        return mEntityConfidence.getEntities().size();
     }
 
     /**
@@ -82,7 +81,7 @@
      */
     @NonNull
     public @EntityType String getEntity(int index) {
-        return mEntities.get(index);
+        return mEntityConfidence.getEntities().get(index);
     }
 
     /**
@@ -126,8 +125,7 @@
 
         private final int mStartIndex;
         private final int mEndIndex;
-        @NonNull private final EntityConfidence<String> mEntityConfidence =
-                new EntityConfidence<>();
+        @NonNull private final Map<String, Float> mEntityConfidence = new ArrayMap<>();
         @NonNull private String mLogSource = "";
         @NonNull private String mVersionInfo = "";
 
@@ -154,7 +152,7 @@
         public Builder setEntityType(
                 @NonNull @EntityType String type,
                 @FloatRange(from = 0.0, to = 1.0) float confidenceScore) {
-            mEntityConfidence.setEntityType(type, confidenceScore);
+            mEntityConfidence.put(Preconditions.checkNotNull(type), confidenceScore);
             return this;
         }
 
@@ -181,7 +179,8 @@
          */
         public TextSelection build() {
             return new TextSelection(
-                    mStartIndex, mEndIndex, mEntityConfidence, mLogSource, mVersionInfo);
+                    mStartIndex, mEndIndex, new EntityConfidence<>(mEntityConfidence),  mLogSource,
+                    mVersionInfo);
         }
     }