Merge "Rank apps by Logistic Regression for Smart-Sharing."
diff --git a/core/java/com/android/internal/app/ChooserActivity.java b/core/java/com/android/internal/app/ChooserActivity.java
index cb7be2e..991ba78 100644
--- a/core/java/com/android/internal/app/ChooserActivity.java
+++ b/core/java/com/android/internal/app/ChooserActivity.java
@@ -565,6 +565,7 @@
             if (ri != null && ri.activityInfo != null) {
                 usageStatsManager.reportChooserSelection(ri.activityInfo.packageName, getUserId(),
                         annotation, null, info.getResolvedIntent().getAction());
+                mResolverComparator.updateModel(info.getResolvedComponentName());
                 if (DEBUG) {
                     Log.d(TAG, "ResolveInfo Package is" + ri.activityInfo.packageName);
                 }
diff --git a/core/java/com/android/internal/app/ResolverActivity.java b/core/java/com/android/internal/app/ResolverActivity.java
index 7c22c4f..f2bd701 100644
--- a/core/java/com/android/internal/app/ResolverActivity.java
+++ b/core/java/com/android/internal/app/ResolverActivity.java
@@ -107,6 +107,7 @@
     private PickTargetOptionRequest mPickOptionRequest;
     private String mReferrerPackage;
 
+    protected ResolverComparator mResolverComparator;
     protected ResolverDrawerLayout mResolverDrawerLayout;
     protected String mContentType;
     protected PackageManager mPm;
diff --git a/core/java/com/android/internal/app/ResolverComparator.java b/core/java/com/android/internal/app/ResolverComparator.java
index 75be906..45fad97 100644
--- a/core/java/com/android/internal/app/ResolverComparator.java
+++ b/core/java/com/android/internal/app/ResolverComparator.java
@@ -27,11 +27,16 @@
 import android.content.pm.ComponentInfo;
 import android.content.pm.PackageManager;
 import android.content.pm.ResolveInfo;
+import android.content.SharedPreferences;
+import android.os.Environment;
+import android.os.storage.StorageManager;
 import android.os.UserHandle;
 import android.text.TextUtils;
+import android.util.ArrayMap;
 import android.util.Log;
 import com.android.internal.app.ResolverActivity.ResolvedComponentInfo;
 
+import java.io.File;
 import java.text.Collator;
 import java.util.ArrayList;
 import java.util.Comparator;
@@ -54,6 +59,12 @@
 
     private static final float RECENCY_MULTIPLIER = 2.f;
 
+    // feature names used in ranking.
+    private static final String LAUNCH_SCORE = "launch";
+    private static final String TIME_SPENT_SCORE = "timeSpent";
+    private static final String RECENCY_SCORE = "recency";
+    private static final String CHOOSER_SCORE = "chooser";
+
     private final Collator mCollator;
     private final boolean mHttp;
     private final PackageManager mPm;
@@ -65,6 +76,7 @@
     private final String mReferrerPackage;
     public String mContentType;
     private String mAction;
+    private LogisticRegressionAppRanker mRanker;
 
     public ResolverComparator(Context context, Intent intent, String referrerPackage) {
         mCollator = Collator.getInstance(context.getResources().getConfiguration().locale);
@@ -80,6 +92,7 @@
         mStats = mUsm.queryAndAggregateUsageStats(mSinceTime, mCurrentTime);
         mContentType = intent.getType();
         mAction = intent.getAction();
+        mRanker = new LogisticRegressionAppRanker(context);
     }
 
     public void compute(List<ResolvedComponentInfo> targets) {
@@ -152,16 +165,13 @@
         for (ScoredTarget target : mScoredTargets.values()) {
             final float recency = (float) Math.max(target.lastTimeUsed - recentSinceTime, 0)
                     / (mostRecentlyUsedTime - recentSinceTime);
-            final float recencyScore = recency * recency * RECENCY_MULTIPLIER;
-            final float usageTimeScore = (float) target.timeSpent / mostTimeSpent;
-            final float launchCountScore = (float) target.launchCount / mostLaunched;
-
-            target.score = recencyScore + usageTimeScore + launchCountScore;
+            target.setFeatures((float) target.launchCount / mostLaunched,
+                    (float) target.timeSpent / mostTimeSpent,
+                    recency * recency * RECENCY_MULTIPLIER,
+                    (float) target.chooserCount / mostSelected);
+            target.selectProb = mRanker.predict(target.getFeatures());
             if (DEBUG) {
-                Log.d(TAG, "Scores: recencyScore: " + recencyScore
-                        + " usageTimeScore: " + usageTimeScore
-                        + " launchCountScore: " + launchCountScore
-                        + " - " + target);
+                Log.d(TAG, "Scores: " + target);
             }
         }
     }
@@ -215,17 +225,11 @@
                 final ScoredTarget rhsTarget = mScoredTargets.get(new ComponentName(
                         rhs.activityInfo.packageName, rhs.activityInfo.name));
 
-                final int chooserCountDiff = Long.compare(
-                        rhsTarget.chooserCount, lhsTarget.chooserCount);
+                final int selectProbDiff = Float.compare(
+                        rhsTarget.selectProb, lhsTarget.selectProb);
 
-                if (chooserCountDiff != 0) {
-                    return chooserCountDiff > 0 ? 1 : -1;
-                }
-
-                final int diff = Float.compare(rhsTarget.score, lhsTarget.score);
-
-                if (diff != 0) {
-                    return diff > 0 ? 1 : -1;
+                if (selectProbDiff != 0) {
+                    return selectProbDiff > 0 ? 1 : -1;
                 }
             }
         }
@@ -241,32 +245,160 @@
     public float getScore(ComponentName name) {
         final ScoredTarget target = mScoredTargets.get(name);
         if (target != null) {
-            return target.score;
+            return target.selectProb; // value produced by mRanker.predict() in compute()
         }
         return 0;
     }
 
     static class ScoredTarget {
         public final ComponentInfo componentInfo;
-        public float score;
         public long lastTimeUsed;
         public long timeSpent;
         public long launchCount;
         public long chooserCount;
+        public ArrayMap<String, Float> features; // feature name -> value; see setFeatures()
+        public float selectProb; // ranker prediction; higher values sort first
 
         public ScoredTarget(ComponentInfo ci) {
             componentInfo = ci;
+            features = new ArrayMap<>(5); // capacity hint; setFeatures() stores four entries
         }
 
         @Override
         public String toString() {
             return "ScoredTarget{" + componentInfo
-                    + " score: " + score
                     + " lastTimeUsed: " + lastTimeUsed
                     + " timeSpent: " + timeSpent
                     + " launchCount: " + launchCount
                     + " chooserCount: " + chooserCount
+                    + " selectProb: " + selectProb
                     + "}";
         }
+
+        public void setFeatures(float launchCountScore, float usageTimeScore, float recencyScore,
+                                float chooserCountScore) {
+            features.put(LAUNCH_SCORE, launchCountScore);
+            features.put(TIME_SPENT_SCORE, usageTimeScore);
+            features.put(RECENCY_SCORE, recencyScore);
+            features.put(CHOOSER_SCORE, chooserCountScore);
+        }
+
+        public ArrayMap<String, Float> getFeatures() {
+            return features; // the live map, not a copy
+        }
+    }
+
+    /** Feeds one chooser selection to the ranker as training data, then persists the model. */
+    public void updateModel(ComponentName componentName) {
+        if (mScoredTargets == null || componentName == null ||
+                !mScoredTargets.containsKey(componentName)) {
+            return;
+        }
+        final ScoredTarget picked = mScoredTargets.get(componentName);
+        for (ComponentName other : mScoredTargets.keySet()) {
+            if (!other.equals(componentName)) {
+                final ScoredTarget rival = mScoredTargets.get(other);
+                // Only rivals scored at least as high as the pick trigger a training step. The
+                // positive step is recomputed per rival; it could be cached or put in closed form.
+                if (rival.selectProb >= picked.selectProb) {
+                    mRanker.update(rival.getFeatures(), rival.selectProb, false);
+                    mRanker.update(picked.getFeatures(), picked.selectProb, true);
+                }
+            }
+        }
+        mRanker.commitUpdate();
+    }
+
+    /** Online logistic-regression scorer whose bias and weights live in SharedPreferences. */
+    class LogisticRegressionAppRanker {
+        private static final String PARAM_SHARED_PREF_NAME = "resolver_ranker_params";
+        private static final String BIAS_PREF_KEY = "bias";
+        private static final float LEARNING_RATE = 0.02f;
+        private static final float REGULARIZER_PARAM = 0.1f;
+        private final SharedPreferences mParamSharedPref;
+        // Null until the first predict() or update() call creates it.
+        private ArrayMap<String, Float> mFeatureWeights;
+        private float mBias;
+
+        public LogisticRegressionAppRanker(Context context) {
+            mParamSharedPref = getParamSharedPref(context);
+        }
+
+        /** Returns sigmoid(bias + sum of weight * feature); 0 for null/empty input or no prefs. */
+        public float predict(ArrayMap<String, Float> target) {
+            if (target == null || mParamSharedPref == null || target.size() == 0) {
+                return 0.0f;
+            }
+            final int featureSize = target.size();
+            if (mFeatureWeights == null) {
+                // First call: load the persisted parameters, defaulting each one to zero.
+                mBias = mParamSharedPref.getFloat(BIAS_PREF_KEY, 0.0f);
+                mFeatureWeights = new ArrayMap<>(featureSize);
+                for (int i = 0; i < featureSize; i++) {
+                    final String featureName = target.keyAt(i);
+                    mFeatureWeights.put(featureName, mParamSharedPref.getFloat(featureName, 0.0f));
+                }
+            }
+            float sum = 0.0f;
+            for (int i = 0; i < featureSize; i++) {
+                sum += mFeatureWeights.getOrDefault(target.keyAt(i), 0.0f) * target.valueAt(i);
+            }
+            return (float) (1.0 / (1.0 + Math.exp(-mBias - sum)));
+        }
+
+        /** Takes one regularized gradient step for a target that was scored {@code predict}. */
+        public void update(ArrayMap<String, Float> target, float predict, boolean isSelected) {
+            if (target == null || target.size() == 0) {
+                return;
+            }
+            final int featureSize = target.size();
+            if (mFeatureWeights == null) {
+                mBias = 0.0f;
+                mFeatureWeights = new ArrayMap<>(featureSize);
+            }
+            final float error = isSelected ? 1.0f - predict : -predict;
+            // Bias gradient is feature-independent: step once per example, not once per feature.
+            mBias += LEARNING_RATE * error;
+            for (int i = 0; i < featureSize; i++) {
+                String featureName = target.keyAt(i);
+                float currentWeight = mFeatureWeights.getOrDefault(featureName, 0.0f);
+                currentWeight = currentWeight - LEARNING_RATE * REGULARIZER_PARAM * currentWeight +
+                        LEARNING_RATE * error * target.valueAt(i);
+                mFeatureWeights.put(featureName, currentWeight);
+            }
+            if (DEBUG) {
+                Log.d(TAG, "Weights: " + mFeatureWeights + " Bias: " + mBias);
+            }
+        }
+
+        /** Persists the bias and every feature weight; no-op if there is nothing to write to. */
+        public void commitUpdate() {
+            if (mParamSharedPref == null || mFeatureWeights == null ||
+                    mFeatureWeights.size() == 0) {
+                return;
+            }
+            SharedPreferences.Editor editor = mParamSharedPref.edit();
+            editor.putFloat(BIAS_PREF_KEY, mBias);
+            final int size = mFeatureWeights.size();
+            for (int i = 0; i < size; i++) {
+                editor.putFloat(mFeatureWeights.keyAt(i), mFeatureWeights.valueAt(i));
+            }
+            editor.apply();
+        }
+
+        private SharedPreferences getParamSharedPref(Context context) {
+            // The package info in the context isn't initialized in the way it is for normal apps,
+            // so the standard, name-based context.getSharedPreferences doesn't work. Instead, we
+            // build the path manually below using the same policy that appears in ContextImpl.
+            if (DEBUG) {
+                Log.d(TAG, "Context Package Name: " + context.getPackageName());
+            }
+            final File prefsFile = new File(new File(
+                    Environment.getDataUserCePackageDirectory(StorageManager.UUID_PRIVATE_INTERNAL,
+                            context.getUserId(), context.getPackageName()),
+                    "shared_prefs"),
+                    PARAM_SHARED_PREF_NAME + ".xml");
+            return context.getSharedPreferences(prefsFile, Context.MODE_PRIVATE);
+        }
+    }
 }