Add AppPredictionServiceResolverComparator
This will sort the share activities based on the APS sorting.
We add a constructor for ResolverListController which takes an
AbstractResolverComparator, so that ChooserActivity may pass in
the APS comparator if it is enabled and available.
Test: Manually tested on APS sorter that did no sorting.
Test: atest frameworks/base/core/tests/coretests/src/com/android/internal/app
Bug: 129014961
Change-Id: I542254ffb0debad45bcd8d5073cc3f3e1bafc616
Signed-off-by: George Hodulik <georgehodulik@google.com>
diff --git a/core/java/com/android/internal/app/AbstractResolverComparator.java b/core/java/com/android/internal/app/AbstractResolverComparator.java
index e091aac..b7276a0 100644
--- a/core/java/com/android/internal/app/AbstractResolverComparator.java
+++ b/core/java/com/android/internal/app/AbstractResolverComparator.java
@@ -1,3 +1,19 @@
+/*
+ * Copyright 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package com.android.internal.app;
import android.app.usage.UsageStatsManager;
@@ -20,7 +36,7 @@
private static final int NUM_OF_TOP_ANNOTATIONS_TO_USE = 3;
- protected AfterCompute mAfterCompute;
+ private AfterCompute mAfterCompute;
protected final PackageManager mPm;
protected final UsageStatsManager mUsm;
protected String[] mAnnotations;
@@ -72,6 +88,13 @@
mAfterCompute = afterCompute;
}
+ protected final void afterCompute() {
+ final AfterCompute afterCompute = mAfterCompute;
+ if (afterCompute != null) {
+ afterCompute.afterCompute();
+ }
+ }
+
@Override
public final int compare(ResolvedComponentInfo lhsp, ResolvedComponentInfo rhsp) {
final ResolveInfo lhs = lhsp.getResolveInfoAt(0);
diff --git a/core/java/com/android/internal/app/AppPredictionServiceResolverComparator.java b/core/java/com/android/internal/app/AppPredictionServiceResolverComparator.java
new file mode 100644
index 0000000..cb44c67
--- /dev/null
+++ b/core/java/com/android/internal/app/AppPredictionServiceResolverComparator.java
@@ -0,0 +1,119 @@
+/*
+ * Copyright 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.internal.app;
+
+import static android.app.prediction.AppTargetEvent.ACTION_LAUNCH;
+
+import android.app.prediction.AppPredictor;
+import android.app.prediction.AppTarget;
+import android.app.prediction.AppTargetEvent;
+import android.app.prediction.AppTargetId;
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.Intent;
+import android.content.pm.ResolveInfo;
+import android.os.UserHandle;
+import android.view.textclassifier.Log;
+
+import com.android.internal.app.ResolverActivity.ResolvedComponentInfo;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * Uses an {@link AppPredictor} to sort Resolver targets.
+ */
+class AppPredictionServiceResolverComparator extends AbstractResolverComparator {
+
+ private static final String TAG = "APSResolverComparator";
+
+ private final AppPredictor mAppPredictor;
+ private final Context mContext;
+ private final Map<ComponentName, Integer> mTargetRanks = new HashMap<>();
+ private final UserHandle mUser;
+
+ AppPredictionServiceResolverComparator(
+ Context context, Intent intent, AppPredictor appPredictor, UserHandle user) {
+ super(context, intent);
+ mContext = context;
+ mAppPredictor = appPredictor;
+ mUser = user;
+ }
+
+ @Override
+ int compare(ResolveInfo lhs, ResolveInfo rhs) {
+ Integer lhsRank = mTargetRanks.get(new ComponentName(lhs.activityInfo.packageName,
+ lhs.activityInfo.name));
+ Integer rhsRank = mTargetRanks.get(new ComponentName(rhs.activityInfo.packageName,
+ rhs.activityInfo.name));
+ if (lhsRank == null && rhsRank == null) {
+ return 0;
+ } else if (lhsRank == null) {
+ return -1;
+ } else if (rhsRank == null) {
+ return 1;
+ }
+ return lhsRank - rhsRank;
+ }
+
+ @Override
+ void compute(List<ResolvedComponentInfo> targets) {
+ List<AppTarget> appTargets = new ArrayList<>();
+ for (ResolvedComponentInfo target : targets) {
+ appTargets.add(new AppTarget.Builder(new AppTargetId(target.name.flattenToString()))
+ .setTarget(target.name.getPackageName(), mUser)
+ .setClassName(target.name.getClassName()).build());
+ }
+ mAppPredictor.sortTargets(appTargets, mContext.getMainExecutor(),
+ sortedAppTargets -> {
+ for (int i = 0; i < sortedAppTargets.size(); i++) {
+ mTargetRanks.put(new ComponentName(sortedAppTargets.get(i).getPackageName(),
+ sortedAppTargets.get(i).getClassName()), i);
+ }
+ afterCompute();
+ });
+ }
+
+ @Override
+ float getScore(ComponentName name) {
+ Integer rank = mTargetRanks.get(name);
+ if (rank == null) {
+ Log.w(TAG, "Score requested for unknown component.");
+ return 0f;
+ }
+ int consecutiveSumOfRanks = (mTargetRanks.size() - 1) * (mTargetRanks.size()) / 2;
+ return 1.0f - (((float) rank) / consecutiveSumOfRanks);
+ }
+
+ @Override
+ void updateModel(ComponentName componentName) {
+ mAppPredictor.notifyAppTargetEvent(
+ new AppTargetEvent.Builder(
+ new AppTarget.Builder(
+ new AppTargetId(componentName.toString()),
+ componentName.getPackageName(), mUser)
+ .setClassName(componentName.getClassName()).build(),
+ ACTION_LAUNCH).build());
+ }
+
+ @Override
+ void destroy() {
+ // Do nothing. App Predictor destruction is handled by caller.
+ }
+}
diff --git a/core/java/com/android/internal/app/ChooserActivity.java b/core/java/com/android/internal/app/ChooserActivity.java
index 54338bf..59e867f 100644
--- a/core/java/com/android/internal/app/ChooserActivity.java
+++ b/core/java/com/android/internal/app/ChooserActivity.java
@@ -150,6 +150,7 @@
*/
// TODO(b/123089490): Replace with system flag
private static final boolean USE_PREDICTION_MANAGER_FOR_DIRECT_TARGETS = false;
+ private static final boolean USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES = false;
// TODO(b/123088566) Share these in a better way.
private static final String APP_PREDICTION_SHARE_UI_SURFACE = "share";
public static final String LAUNCH_LOCATON_DIRECT_SHARE = "direct_share";
@@ -1387,6 +1388,15 @@
return USE_PREDICTION_MANAGER_FOR_DIRECT_TARGETS ? getAppPredictor() : null;
}
+ /**
+ * This will return an app predictor if it is enabled for share activity sorting
+ * and if one exists. Otherwise, it returns null.
+ */
+ @Nullable
+ private AppPredictor getAppPredictorForShareActivitesIfEnabled() {
+ return USE_PREDICTION_MANAGER_FOR_SHARE_ACTIVITIES ? getAppPredictor() : null;
+ }
+
void onRefinementResult(TargetInfo selectedTarget, Intent matchingIntent) {
if (mRefinementResultReceiver != null) {
mRefinementResultReceiver.destroy();
@@ -1491,8 +1501,10 @@
PackageManager pm,
Intent targetIntent,
String referrerPackageName,
- int launchedFromUid) {
- super(context, pm, targetIntent, referrerPackageName, launchedFromUid);
+ int launchedFromUid,
+ AbstractResolverComparator resolverComparator) {
+ super(context, pm, targetIntent, referrerPackageName, launchedFromUid,
+ resolverComparator);
}
@Override
@@ -1520,13 +1532,24 @@
@VisibleForTesting
protected ResolverListController createListController() {
+ AppPredictor appPredictor = getAppPredictorForShareActivitesIfEnabled();
+ AbstractResolverComparator resolverComparator;
+ if (appPredictor != null) {
+ resolverComparator = new AppPredictionServiceResolverComparator(this, getTargetIntent(),
+ appPredictor, getUser());
+ } else {
+ resolverComparator =
+ new ResolverRankerServiceResolverComparator(this, getTargetIntent(),
+ getReferrerPackageName(), null);
+ }
+
return new ChooserListController(
this,
mPm,
getTargetIntent(),
getReferrerPackageName(),
- mLaunchedFromUid
- );
+ mLaunchedFromUid,
+ resolverComparator);
}
@VisibleForTesting
diff --git a/core/java/com/android/internal/app/ResolverListController.java b/core/java/com/android/internal/app/ResolverListController.java
index a3cfa87..5f92cdd 100644
--- a/core/java/com/android/internal/app/ResolverListController.java
+++ b/core/java/com/android/internal/app/ResolverListController.java
@@ -63,14 +63,24 @@
Intent targetIntent,
String referrerPackage,
int launchedFromUid) {
+ this(context, pm, targetIntent, referrerPackage, launchedFromUid,
+ new ResolverRankerServiceResolverComparator(
+ context, targetIntent, referrerPackage, null));
+ }
+
+ public ResolverListController(
+ Context context,
+ PackageManager pm,
+ Intent targetIntent,
+ String referrerPackage,
+ int launchedFromUid,
+ AbstractResolverComparator resolverComparator) {
mContext = context;
mpm = pm;
mLaunchedFromUid = launchedFromUid;
mTargetIntent = targetIntent;
mReferrerPackage = referrerPackage;
- mResolverComparator =
- new ResolverRankerServiceResolverComparator(
- mContext, mTargetIntent, mReferrerPackage, null);
+ mResolverComparator = resolverComparator;
}
@VisibleForTesting
diff --git a/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java b/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java
index 9bf4f01..726b186 100644
--- a/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java
+++ b/core/java/com/android/internal/app/ResolverRankerServiceResolverComparator.java
@@ -126,7 +126,7 @@
Log.e(TAG, "Receiving null prediction results.");
}
mHandler.removeMessages(RESOLVER_RANKER_RESULT_TIMEOUT);
- mAfterCompute.afterCompute();
+ afterCompute();
}
break;
@@ -135,7 +135,7 @@
Log.d(TAG, "RESOLVER_RANKER_RESULT_TIMEOUT; unbinding services");
}
mHandler.removeMessages(RESOLVER_RANKER_SERVICE_RESULT);
- mAfterCompute.afterCompute();
+ afterCompute();
break;
default:
@@ -149,7 +149,6 @@
super(context, intent);
mCollator = Collator.getInstance(context.getResources().getConfiguration().locale);
mReferrerPackage = referrerPackage;
- mAfterCompute = afterCompute;
mContext = context;
mCurrentTime = System.currentTimeMillis();
@@ -157,6 +156,7 @@
mStats = mUsm.queryAndAggregateUsageStats(mSinceTime, mCurrentTime);
mAction = intent.getAction();
mRankerServiceName = new ComponentName(mContext, this.getClass());
+ setCallBack(afterCompute);
}
// compute features for each target according to usage stats of targets.
@@ -328,9 +328,7 @@
mContext.unbindService(mConnection);
mConnection.destroy();
}
- if (mAfterCompute != null) {
- mAfterCompute.afterCompute();
- }
+ afterCompute();
if (DEBUG) {
Log.d(TAG, "Unbinded Resolver Ranker.");
}
@@ -513,9 +511,7 @@
Log.e(TAG, "Error in Predict: " + e);
}
}
- if (mAfterCompute != null) {
- mAfterCompute.afterCompute();
- }
+ afterCompute();
}
// adds select prob as the default values, according to a pre-trained Logistic Regression model.