blob: 69a59a5b4b3686ca0e4ffdf901fa6b20f22c6bab [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17package android.view.textclassifier;
18
19import android.annotation.FloatRange;
20import android.annotation.NonNull;
Jan Althaus0d9fbb92017-11-28 12:19:33 +010021import android.os.Parcel;
22import android.os.Parcelable;
Jan Althausbbe43df2017-11-30 15:01:40 +010023import android.util.ArrayMap;
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000024
25import com.android.internal.util.Preconditions;
26
27import java.util.ArrayList;
28import java.util.Collections;
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000029import java.util.List;
30import java.util.Map;
31
32/**
33 * Helper object for setting and getting entity scores for classified text.
34 *
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000035 * @hide
36 */
Jan Althaus0d9fbb92017-11-28 12:19:33 +010037final class EntityConfidence implements Parcelable {
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000038
Jan Althaus0d9fbb92017-11-28 12:19:33 +010039 private final ArrayMap<String, Float> mEntityConfidence = new ArrayMap<>();
40 private final ArrayList<String> mSortedEntities = new ArrayList<>();
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000041
42 EntityConfidence() {}
43
Jan Althaus0d9fbb92017-11-28 12:19:33 +010044 EntityConfidence(@NonNull EntityConfidence source) {
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000045 Preconditions.checkNotNull(source);
46 mEntityConfidence.putAll(source.mEntityConfidence);
Jan Althausbbe43df2017-11-30 15:01:40 +010047 mSortedEntities.addAll(source.mSortedEntities);
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000048 }
49
50 /**
Jan Althausbbe43df2017-11-30 15:01:40 +010051 * Constructs an EntityConfidence from a map of entity to confidence.
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000052 *
Jan Althausbbe43df2017-11-30 15:01:40 +010053 * Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1.
54 *
55 * @param source a map from entity to a confidence value in the range 0 (low confidence) to
56 * 1 (high confidence).
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000057 */
Jan Althaus0d9fbb92017-11-28 12:19:33 +010058 EntityConfidence(@NonNull Map<String, Float> source) {
Jan Althausbbe43df2017-11-30 15:01:40 +010059 Preconditions.checkNotNull(source);
60
61 // Prune non-existent entities and clamp to 1.
62 mEntityConfidence.ensureCapacity(source.size());
Jan Althaus0d9fbb92017-11-28 12:19:33 +010063 for (Map.Entry<String, Float> it : source.entrySet()) {
Jan Althausbbe43df2017-11-30 15:01:40 +010064 if (it.getValue() <= 0) continue;
65 mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000066 }
Jan Althaus0d9fbb92017-11-28 12:19:33 +010067 resetSortedEntitiesFromMap();
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000068 }
69
70 /**
71 * Returns an immutable list of entities found in the classified text ordered from
72 * high confidence to low confidence.
73 */
74 @NonNull
Jan Althaus0d9fbb92017-11-28 12:19:33 +010075 public List<String> getEntities() {
Jan Althausbbe43df2017-11-30 15:01:40 +010076 return Collections.unmodifiableList(mSortedEntities);
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000077 }
78
79 /**
80 * Returns the confidence score for the specified entity. The value ranges from
81 * 0 (low confidence) to 1 (high confidence). 0 indicates that the entity was not found for the
82 * classified text.
83 */
84 @FloatRange(from = 0.0, to = 1.0)
Jan Althaus0d9fbb92017-11-28 12:19:33 +010085 public float getConfidenceScore(String entity) {
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +000086 if (mEntityConfidence.containsKey(entity)) {
87 return mEntityConfidence.get(entity);
88 }
89 return 0;
90 }
91
92 @Override
93 public String toString() {
94 return mEntityConfidence.toString();
95 }
Jan Althaus0d9fbb92017-11-28 12:19:33 +010096
97 @Override
98 public int describeContents() {
99 return 0;
100 }
101
102 @Override
103 public void writeToParcel(Parcel dest, int flags) {
104 dest.writeInt(mEntityConfidence.size());
105 for (Map.Entry<String, Float> entry : mEntityConfidence.entrySet()) {
106 dest.writeString(entry.getKey());
107 dest.writeFloat(entry.getValue());
108 }
109 }
110
111 public static final Parcelable.Creator<EntityConfidence> CREATOR =
112 new Parcelable.Creator<EntityConfidence>() {
113 @Override
114 public EntityConfidence createFromParcel(Parcel in) {
115 return new EntityConfidence(in);
116 }
117
118 @Override
119 public EntityConfidence[] newArray(int size) {
120 return new EntityConfidence[size];
121 }
122 };
123
124 private EntityConfidence(Parcel in) {
125 final int numEntities = in.readInt();
126 mEntityConfidence.ensureCapacity(numEntities);
127 for (int i = 0; i < numEntities; ++i) {
128 mEntityConfidence.put(in.readString(), in.readFloat());
129 }
130 resetSortedEntitiesFromMap();
131 }
132
133 private void resetSortedEntitiesFromMap() {
134 mSortedEntities.clear();
135 mSortedEntities.ensureCapacity(mEntityConfidence.size());
136 mSortedEntities.addAll(mEntityConfidence.keySet());
137 mSortedEntities.sort((e1, e2) -> {
138 float score1 = mEntityConfidence.get(e1);
139 float score2 = mEntityConfidence.get(e2);
140 return Float.compare(score2, score1);
141 });
142 }
Abodunrinwa Tokif001fef2017-01-04 23:51:42 +0000143}