/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| 16 | |
| 17 | package android.view.textclassifier; |
| 18 | |
| 19 | import android.annotation.FloatRange; |
| 20 | import android.annotation.NonNull; |
Jan Althaus | 0d9fbb9 | 2017-11-28 12:19:33 +0100 | [diff] [blame] | 21 | import android.os.Parcel; |
| 22 | import android.os.Parcelable; |
Jan Althaus | bbe43df | 2017-11-30 15:01:40 +0100 | [diff] [blame] | 23 | import android.util.ArrayMap; |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 24 | |
| 25 | import com.android.internal.util.Preconditions; |
| 26 | |
| 27 | import java.util.ArrayList; |
| 28 | import java.util.Collections; |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 29 | import java.util.List; |
| 30 | import java.util.Map; |
| 31 | |
| 32 | /** |
| 33 | * Helper object for setting and getting entity scores for classified text. |
| 34 | * |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 35 | * @hide |
| 36 | */ |
Jan Althaus | 0d9fbb9 | 2017-11-28 12:19:33 +0100 | [diff] [blame] | 37 | final class EntityConfidence implements Parcelable { |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 38 | |
Jan Althaus | 0d9fbb9 | 2017-11-28 12:19:33 +0100 | [diff] [blame] | 39 | private final ArrayMap<String, Float> mEntityConfidence = new ArrayMap<>(); |
| 40 | private final ArrayList<String> mSortedEntities = new ArrayList<>(); |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 41 | |
| 42 | EntityConfidence() {} |
| 43 | |
Jan Althaus | 0d9fbb9 | 2017-11-28 12:19:33 +0100 | [diff] [blame] | 44 | EntityConfidence(@NonNull EntityConfidence source) { |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 45 | Preconditions.checkNotNull(source); |
| 46 | mEntityConfidence.putAll(source.mEntityConfidence); |
Jan Althaus | bbe43df | 2017-11-30 15:01:40 +0100 | [diff] [blame] | 47 | mSortedEntities.addAll(source.mSortedEntities); |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 48 | } |
| 49 | |
| 50 | /** |
Jan Althaus | bbe43df | 2017-11-30 15:01:40 +0100 | [diff] [blame] | 51 | * Constructs an EntityConfidence from a map of entity to confidence. |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 52 | * |
Jan Althaus | bbe43df | 2017-11-30 15:01:40 +0100 | [diff] [blame] | 53 | * Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1. |
| 54 | * |
| 55 | * @param source a map from entity to a confidence value in the range 0 (low confidence) to |
| 56 | * 1 (high confidence). |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 57 | */ |
Jan Althaus | 0d9fbb9 | 2017-11-28 12:19:33 +0100 | [diff] [blame] | 58 | EntityConfidence(@NonNull Map<String, Float> source) { |
Jan Althaus | bbe43df | 2017-11-30 15:01:40 +0100 | [diff] [blame] | 59 | Preconditions.checkNotNull(source); |
| 60 | |
| 61 | // Prune non-existent entities and clamp to 1. |
| 62 | mEntityConfidence.ensureCapacity(source.size()); |
Jan Althaus | 0d9fbb9 | 2017-11-28 12:19:33 +0100 | [diff] [blame] | 63 | for (Map.Entry<String, Float> it : source.entrySet()) { |
Jan Althaus | bbe43df | 2017-11-30 15:01:40 +0100 | [diff] [blame] | 64 | if (it.getValue() <= 0) continue; |
| 65 | mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue())); |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 66 | } |
Jan Althaus | 0d9fbb9 | 2017-11-28 12:19:33 +0100 | [diff] [blame] | 67 | resetSortedEntitiesFromMap(); |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 68 | } |
| 69 | |
| 70 | /** |
| 71 | * Returns an immutable list of entities found in the classified text ordered from |
| 72 | * high confidence to low confidence. |
| 73 | */ |
| 74 | @NonNull |
Jan Althaus | 0d9fbb9 | 2017-11-28 12:19:33 +0100 | [diff] [blame] | 75 | public List<String> getEntities() { |
Jan Althaus | bbe43df | 2017-11-30 15:01:40 +0100 | [diff] [blame] | 76 | return Collections.unmodifiableList(mSortedEntities); |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 77 | } |
| 78 | |
| 79 | /** |
| 80 | * Returns the confidence score for the specified entity. The value ranges from |
| 81 | * 0 (low confidence) to 1 (high confidence). 0 indicates that the entity was not found for the |
| 82 | * classified text. |
| 83 | */ |
| 84 | @FloatRange(from = 0.0, to = 1.0) |
Jan Althaus | 0d9fbb9 | 2017-11-28 12:19:33 +0100 | [diff] [blame] | 85 | public float getConfidenceScore(String entity) { |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 86 | if (mEntityConfidence.containsKey(entity)) { |
| 87 | return mEntityConfidence.get(entity); |
| 88 | } |
| 89 | return 0; |
| 90 | } |
| 91 | |
| 92 | @Override |
| 93 | public String toString() { |
| 94 | return mEntityConfidence.toString(); |
| 95 | } |
Jan Althaus | 0d9fbb9 | 2017-11-28 12:19:33 +0100 | [diff] [blame] | 96 | |
| 97 | @Override |
| 98 | public int describeContents() { |
| 99 | return 0; |
| 100 | } |
| 101 | |
| 102 | @Override |
| 103 | public void writeToParcel(Parcel dest, int flags) { |
| 104 | dest.writeInt(mEntityConfidence.size()); |
| 105 | for (Map.Entry<String, Float> entry : mEntityConfidence.entrySet()) { |
| 106 | dest.writeString(entry.getKey()); |
| 107 | dest.writeFloat(entry.getValue()); |
| 108 | } |
| 109 | } |
| 110 | |
Jeff Sharkey | 9e8f83d | 2019-02-28 12:06:45 -0700 | [diff] [blame] | 111 | public static final @android.annotation.NonNull Parcelable.Creator<EntityConfidence> CREATOR = |
Jan Althaus | 0d9fbb9 | 2017-11-28 12:19:33 +0100 | [diff] [blame] | 112 | new Parcelable.Creator<EntityConfidence>() { |
| 113 | @Override |
| 114 | public EntityConfidence createFromParcel(Parcel in) { |
| 115 | return new EntityConfidence(in); |
| 116 | } |
| 117 | |
| 118 | @Override |
| 119 | public EntityConfidence[] newArray(int size) { |
| 120 | return new EntityConfidence[size]; |
| 121 | } |
| 122 | }; |
| 123 | |
| 124 | private EntityConfidence(Parcel in) { |
| 125 | final int numEntities = in.readInt(); |
| 126 | mEntityConfidence.ensureCapacity(numEntities); |
| 127 | for (int i = 0; i < numEntities; ++i) { |
| 128 | mEntityConfidence.put(in.readString(), in.readFloat()); |
| 129 | } |
| 130 | resetSortedEntitiesFromMap(); |
| 131 | } |
| 132 | |
| 133 | private void resetSortedEntitiesFromMap() { |
| 134 | mSortedEntities.clear(); |
| 135 | mSortedEntities.ensureCapacity(mEntityConfidence.size()); |
| 136 | mSortedEntities.addAll(mEntityConfidence.keySet()); |
| 137 | mSortedEntities.sort((e1, e2) -> { |
| 138 | float score1 = mEntityConfidence.get(e1); |
| 139 | float score2 = mEntityConfidence.get(e2); |
| 140 | return Float.compare(score2, score1); |
| 141 | }); |
| 142 | } |
Abodunrinwa Toki | f001fef | 2017-01-04 23:51:42 +0000 | [diff] [blame] | 143 | } |