Merge "Merge "Add icon packs for themepicker." into qt-dev am: 2c09b3be23" into qt-dev-plus-aosp
diff --git a/packages/ExtServices/src/android/ext/services/notification/Assistant.java b/packages/ExtServices/src/android/ext/services/notification/Assistant.java
index b2baff5..73c7895 100644
--- a/packages/ExtServices/src/android/ext/services/notification/Assistant.java
+++ b/packages/ExtServices/src/android/ext/services/notification/Assistant.java
@@ -23,6 +23,7 @@
 
 import android.annotation.NonNull;
 import android.annotation.Nullable;
+import android.annotation.SuppressLint;
 import android.app.ActivityThread;
 import android.app.INotificationManager;
 import android.app.Notification;
@@ -71,6 +72,7 @@
 /**
  * Notification assistant that provides guidance on notification channel blocking
  */
+@SuppressLint("OverrideAbstract")
 public class Assistant extends NotificationAssistantService {
     private static final String TAG = "ExtAssistant";
     private static final boolean DEBUG = Log.isLoggable(TAG, Log.DEBUG);
@@ -238,7 +240,7 @@
         }
         mSingleThreadExecutor.submit(() -> {
             NotificationEntry entry =
-                    new NotificationEntry(mPackageManager, sbn, channel, mSmsHelper);
+                    new NotificationEntry(getContext(), mPackageManager, sbn, channel, mSmsHelper);
             SmartActionsHelper.SmartSuggestions suggestions = mSmartActionsHelper.suggest(entry);
             if (DEBUG) {
                 Log.d(TAG, String.format(
@@ -295,7 +297,7 @@
             }
             Ranking ranking = getRanking(sbn.getKey(), rankingMap);
             if (ranking != null && ranking.getChannel() != null) {
-                NotificationEntry entry = new NotificationEntry(mPackageManager,
+                NotificationEntry entry = new NotificationEntry(getContext(), mPackageManager,
                         sbn, ranking.getChannel(), mSmsHelper);
                 String key = getKey(
                         sbn.getPackageName(), sbn.getUserId(), ranking.getChannel().getId());
diff --git a/packages/ExtServices/src/android/ext/services/notification/NotificationEntry.java b/packages/ExtServices/src/android/ext/services/notification/NotificationEntry.java
index 84a8a8c..1ffbac9 100644
--- a/packages/ExtServices/src/android/ext/services/notification/NotificationEntry.java
+++ b/packages/ExtServices/src/android/ext/services/notification/NotificationEntry.java
@@ -28,18 +28,23 @@
 import android.app.Person;
 import android.app.RemoteInput;
 import android.content.ComponentName;
+import android.content.Context;
 import android.content.pm.ApplicationInfo;
 import android.content.pm.IPackageManager;
 import android.content.pm.PackageManager;
+import android.graphics.drawable.Icon;
 import android.media.AudioAttributes;
 import android.media.AudioSystem;
 import android.os.Build;
+import android.os.Parcelable;
 import android.os.RemoteException;
 import android.service.notification.StatusBarNotification;
 import android.util.Log;
+import android.util.SparseArray;
 
 import java.util.ArrayList;
 import java.util.Objects;
+import java.util.Set;
 
 /**
  * Holds data about notifications.
@@ -47,6 +52,10 @@
 public class NotificationEntry {
     static final String TAG = "NotificationEntry";
 
+    // Copied from hidden definitions in Notification.TvExtender
+    private static final String EXTRA_TV_EXTENDER = "android.tv.EXTENSIONS";
+
+    private final Context mContext;
     private final StatusBarNotification mSbn;
     private final IPackageManager mPackageManager;
     private int mTargetSdkVersion = Build.VERSION_CODES.N_MR1;
@@ -60,9 +69,10 @@
 
     private final Object mLock = new Object();
 
-    public NotificationEntry(IPackageManager packageManager, StatusBarNotification sbn,
-            NotificationChannel channel, SmsHelper smsHelper) {
-        mSbn = sbn;
+    public NotificationEntry(Context applicationContext, IPackageManager packageManager,
+            StatusBarNotification sbn, NotificationChannel channel, SmsHelper smsHelper) {
+        mContext = applicationContext;
+        mSbn = cloneStatusBarNotificationLight(sbn);
         mChannel = channel;
         mPackageManager = packageManager;
         mPreChannelsNotification = isPreChannelsNotification();
@@ -71,6 +81,66 @@
         mSmsHelper = smsHelper;
     }
 
+    /**
+     * Strips RemoteViews, the legacy large-icon bitmap and Parcelable-like extras in place
+     * (adapted from {@code Notification.lightenPayload}); messaging/people extras are kept.
+     */
+    @SuppressWarnings("nullness")
+    private static void lightenNotificationPayload(Notification notification) {
+        notification.tickerView = null;
+        notification.contentView = null;
+        notification.bigContentView = null;
+        notification.headsUpContentView = null;
+        notification.largeIcon = null;
+        if (notification.extras == null || notification.extras.isEmpty()) {
+            return;
+        }
+        final Set<String> keyset = notification.extras.keySet();
+        // Iterate over a snapshot array because matching keys are removed from the Bundle.
+        for (String extraKey : keyset.toArray(new String[keyset.size()])) {
+            final boolean keep = EXTRA_TV_EXTENDER.equals(extraKey)
+                    || Notification.EXTRA_MESSAGES.equals(extraKey)
+                    || Notification.EXTRA_MESSAGING_PERSON.equals(extraKey)
+                    || Notification.EXTRA_PEOPLE_LIST.equals(extraKey);
+            if (keep) {
+                continue;
+            }
+            final Object value = notification.extras.get(extraKey);
+            if (value instanceof Parcelable || value instanceof Parcelable[]
+                    || value instanceof SparseArray || value instanceof ArrayList) {
+                notification.extras.remove(extraKey);
+            }
+        }
+    }
+
+    /** Returns a stripped copy of {@code notification}; the argument itself is not modified. */
+    private Notification cloneNotificationLight(Notification notification) {
+        // recoverBuilder() adopts rather than copies its argument, so clone() first; the builder
+        // (which needs a Context) is the only way to clear the icons, here applied to the copy.
+        Notification lightNotification =
+                Notification.Builder.recoverBuilder(mContext, notification.clone())
+                        .setSmallIcon(0)
+                        .setLargeIcon((Icon) null)
+                        .build();
+        lightenNotificationPayload(lightNotification);
+        return lightNotification;
+    }
+
+    /** Adapted from {@code StatusBarNotification.cloneLight}: only the Notification shrinks. */
+    public StatusBarNotification cloneStatusBarNotificationLight(StatusBarNotification sbn) {
+        // Keep initialPid (as cloneLight does) rather than zeroing it.
+        StatusBarNotification lightSbn = new StatusBarNotification(
+                sbn.getPackageName(), sbn.getOpPkg(), sbn.getId(), sbn.getTag(), sbn.getUid(),
+                sbn.getInitialPid(), /*score=*/ 0,
+                cloneNotificationLight(sbn.getNotification()),
+                sbn.getUser(),
+                sbn.getPostTime());
+        // This constructor has no overrideGroupKey parameter; restore it explicitly so that
+        // getOverrideGroupKey()/getGroupKey() agree with the original sbn.
+        lightSbn.setOverrideGroupKey(sbn.getOverrideGroupKey());
+        return lightSbn;
+    }
+
     private boolean isPreChannelsNotification() {
         try {
             ApplicationInfo info = mPackageManager.getApplicationInfo(
diff --git a/packages/ExtServices/tests/src/android/ext/services/notification/AgingHelperTest.java b/packages/ExtServices/tests/src/android/ext/services/notification/AgingHelperTest.java
index 3db275a..a87d57c 100644
--- a/packages/ExtServices/tests/src/android/ext/services/notification/AgingHelperTest.java
+++ b/packages/ExtServices/tests/src/android/ext/services/notification/AgingHelperTest.java
@@ -102,7 +102,8 @@
     public void testNoSnoozingOnPost() {
         NotificationChannel channel = new NotificationChannel("", "", IMPORTANCE_HIGH);
         StatusBarNotification sbn = generateSbn(channel.getId());
-        NotificationEntry entry = new NotificationEntry(mPackageManager, sbn, channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, sbn, channel, mSmsHelper);
 
 
         mAgingHelper.onNotificationPosted(entry);
@@ -113,7 +114,8 @@
     public void testPostResetsSnooze() {
         NotificationChannel channel = new NotificationChannel("", "", IMPORTANCE_HIGH);
         StatusBarNotification sbn = generateSbn(channel.getId());
-        NotificationEntry entry = new NotificationEntry(mPackageManager, sbn, channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, sbn, channel, mSmsHelper);
 
 
         mAgingHelper.onNotificationPosted(entry);
@@ -124,7 +126,8 @@
     public void testSnoozingOnSeen() {
         NotificationChannel channel = new NotificationChannel("", "", IMPORTANCE_HIGH);
         StatusBarNotification sbn = generateSbn(channel.getId());
-        NotificationEntry entry = new NotificationEntry(mPackageManager, sbn, channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, sbn, channel, mSmsHelper);
         entry.setSeen();
         when(mCategorizer.getCategory(entry)).thenReturn(NotificationCategorizer.CATEGORY_PEOPLE);
 
@@ -137,7 +140,8 @@
         NotificationChannel channel = new NotificationChannel("", "", IMPORTANCE_HIGH);
         channel.lockFields(NotificationChannel.USER_LOCKED_IMPORTANCE);
         StatusBarNotification sbn = generateSbn(channel.getId());
-        NotificationEntry entry = new NotificationEntry(mPackageManager, sbn, channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, sbn, channel, mSmsHelper);
         when(mCategorizer.getCategory(entry)).thenReturn(NotificationCategorizer.CATEGORY_PEOPLE);
 
         mAgingHelper.onNotificationSeen(entry);
diff --git a/packages/ExtServices/tests/src/android/ext/services/notification/AssistantTest.java b/packages/ExtServices/tests/src/android/ext/services/notification/AssistantTest.java
index ee29bc5..012dcc0 100644
--- a/packages/ExtServices/tests/src/android/ext/services/notification/AssistantTest.java
+++ b/packages/ExtServices/tests/src/android/ext/services/notification/AssistantTest.java
@@ -468,8 +468,10 @@
     @Test
     public void testAssistantNeverIncreasesImportanceWhenSuggestingSilent() throws Exception {
         StatusBarNotification sbn = generateSbn(PKG1, UID1, P1C3, "min notif!", null);
-        Adjustment adjust = mAssistant.createEnqueuedNotificationAdjustment(new NotificationEntry(
-                mPackageManager, sbn, P1C3, mSmsHelper), new ArrayList<>(), new ArrayList<>());
+        Adjustment adjust = mAssistant.createEnqueuedNotificationAdjustment(
+                new NotificationEntry(mContext, mPackageManager, sbn, P1C3, mSmsHelper),
+                new ArrayList<>(),
+                new ArrayList<>());
         assertEquals(IMPORTANCE_MIN, adjust.getSignals().getInt(Adjustment.KEY_IMPORTANCE));
     }
 }
diff --git a/packages/ExtServices/tests/src/android/ext/services/notification/NotificationEntryTest.java b/packages/ExtServices/tests/src/android/ext/services/notification/NotificationEntryTest.java
index f51e911..c026079 100644
--- a/packages/ExtServices/tests/src/android/ext/services/notification/NotificationEntryTest.java
+++ b/packages/ExtServices/tests/src/android/ext/services/notification/NotificationEntryTest.java
@@ -24,6 +24,7 @@
 import static junit.framework.Assert.assertFalse;
 import static junit.framework.Assert.assertTrue;
 
+import static org.junit.Assert.assertNull;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.when;
@@ -34,6 +35,8 @@
 import android.content.ComponentName;
 import android.content.pm.ApplicationInfo;
 import android.content.pm.IPackageManager;
+import android.graphics.Bitmap;
+import android.graphics.drawable.Icon;
 import android.media.AudioAttributes;
 import android.os.Build;
 import android.os.Process;
@@ -41,9 +44,6 @@
 import android.service.notification.StatusBarNotification;
 import android.testing.TestableContext;
 
-import androidx.test.InstrumentationRegistry;
-import androidx.test.runner.AndroidJUnit4;
-
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -53,6 +53,9 @@
 
 import java.util.ArrayList;
 
+import androidx.test.InstrumentationRegistry;
+import androidx.test.runner.AndroidJUnit4;
+
 @RunWith(AndroidJUnit4.class)
 public class NotificationEntryTest {
     private String mPkg = "pkg";
@@ -113,7 +116,8 @@
         people.add(new Person.Builder().setKey("mailto:testing@android.com").build());
         sbn.getNotification().extras.putParcelableArrayList(Notification.EXTRA_PEOPLE_LIST, people);
 
-        NotificationEntry entry = new NotificationEntry(mPackageManager, sbn, channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, sbn, channel, mSmsHelper);
         assertTrue(entry.involvesPeople());
     }
 
@@ -121,7 +125,8 @@
     public void testNotPerson() {
         NotificationChannel channel = new NotificationChannel("", "", IMPORTANCE_HIGH);
         StatusBarNotification sbn = generateSbn(channel.getId());
-        NotificationEntry entry = new NotificationEntry(mPackageManager, sbn, channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, sbn, channel, mSmsHelper);
         assertFalse(entry.involvesPeople());
     }
 
@@ -129,7 +134,8 @@
     public void testHasPerson_matchesDefaultSmsApp() {
         NotificationChannel channel = new NotificationChannel("", "", IMPORTANCE_HIGH);
         StatusBarNotification sbn = generateSbn(channel.getId(), DEFAULT_SMS_PACKAGE_NAME);
-        NotificationEntry entry = new NotificationEntry(mPackageManager, sbn, channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, sbn, channel, mSmsHelper);
         assertTrue(entry.involvesPeople());
     }
 
@@ -137,7 +143,8 @@
     public void testHasPerson_doesntMatchDefaultSmsApp() {
         NotificationChannel channel = new NotificationChannel("", "", IMPORTANCE_HIGH);
         StatusBarNotification sbn = generateSbn(channel.getId(), "abc");
-        NotificationEntry entry = new NotificationEntry(mPackageManager, sbn, channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, sbn, channel, mSmsHelper);
         assertFalse(entry.involvesPeople());
     }
 
@@ -148,8 +155,8 @@
         Notification n = new Notification.Builder(mContext, channel.getId())
                 .setStyle(new Notification.InboxStyle())
                 .build();
-        NotificationEntry entry =
-                new NotificationEntry(mPackageManager, generateSbn(n), channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, generateSbn(n), channel, mSmsHelper);
         assertTrue(entry.hasStyle(Notification.InboxStyle.class));
     }
 
@@ -160,8 +167,8 @@
         Notification n = new Notification.Builder(mContext, channel.getId())
                 .setStyle(new Notification.MessagingStyle(""))
                 .build();
-        NotificationEntry entry =
-                new NotificationEntry(mPackageManager, generateSbn(n), channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, generateSbn(n), channel, mSmsHelper);
         assertTrue(entry.hasStyle(Notification.MessagingStyle.class));
     }
 
@@ -172,8 +179,8 @@
         Notification n = new Notification.Builder(mContext, channel.getId())
                 .setStyle(new Notification.BigPictureStyle())
                 .build();
-        NotificationEntry entry =
-                new NotificationEntry(mPackageManager, generateSbn(n), channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, generateSbn(n), channel, mSmsHelper);
         assertFalse(entry.hasStyle(Notification.InboxStyle.class));
         assertFalse(entry.hasStyle(Notification.MessagingStyle.class));
     }
@@ -184,7 +191,7 @@
         channel.setSound(null, new AudioAttributes.Builder().setUsage(USAGE_ALARM).build());
 
         NotificationEntry entry = new NotificationEntry(
-                mPackageManager, generateSbn(channel.getId()), channel, mSmsHelper);
+                mContext, mPackageManager, generateSbn(channel.getId()), channel, mSmsHelper);
 
         assertTrue(entry.isAudioAttributesUsage(USAGE_ALARM));
     }
@@ -193,7 +200,7 @@
     public void testIsNotAudioAttributes() {
         NotificationChannel channel = new NotificationChannel("", "", IMPORTANCE_HIGH);
         NotificationEntry entry = new NotificationEntry(
-                mPackageManager, generateSbn(channel.getId()), channel, mSmsHelper);
+                mContext, mPackageManager, generateSbn(channel.getId()), channel, mSmsHelper);
 
         assertFalse(entry.isAudioAttributesUsage(USAGE_ALARM));
     }
@@ -205,8 +212,8 @@
         Notification n = new Notification.Builder(mContext, channel.getId())
                 .setCategory(Notification.CATEGORY_EMAIL)
                 .build();
-        NotificationEntry entry =
-                new NotificationEntry(mPackageManager, generateSbn(n), channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, generateSbn(n), channel, mSmsHelper);
 
         assertTrue(entry.isCategory(Notification.CATEGORY_EMAIL));
         assertFalse(entry.isCategory(Notification.CATEGORY_MESSAGE));
@@ -219,8 +226,8 @@
         Notification n = new Notification.Builder(mContext, channel.getId())
                 .setFlag(FLAG_FOREGROUND_SERVICE, true)
                 .build();
-        NotificationEntry entry =
-                new NotificationEntry(mPackageManager, generateSbn(n), channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, generateSbn(n), channel, mSmsHelper);
 
         assertTrue(entry.isOngoing());
     }
@@ -232,9 +239,28 @@
         Notification n = new Notification.Builder(mContext, channel.getId())
                 .setFlag(FLAG_CAN_COLORIZE, true)
                 .build();
-        NotificationEntry entry =
-                new NotificationEntry(mPackageManager, generateSbn(n), channel, mSmsHelper);
+        NotificationEntry entry = new NotificationEntry(
+                mContext, mPackageManager, generateSbn(n), channel, mSmsHelper);
 
         assertFalse(entry.isOngoing());
     }
+
+    @Test
+    public void testShrinkNotification() {
+        NotificationChannel channel = new NotificationChannel("", "", IMPORTANCE_HIGH);
+        Notification original = new Notification.Builder(mContext, "")
+                .setSmallIcon(android.R.drawable.sym_def_app_icon)
+                .setLargeIcon(
+                        Icon.createWithResource(mContext, android.R.drawable.alert_dark_frame))
+                .build();
+        // Also populate the legacy Bitmap field, which lightening must clear.
+        original.largeIcon = Bitmap.createBitmap(100, 200, Bitmap.Config.RGB_565);
+        Notification light = new NotificationEntry(mContext, mPackageManager,
+                generateSbn(original), channel, mSmsHelper).getNotification();
+
+        assertNull(light.getSmallIcon());
+        assertNull(light.getLargeIcon());
+        assertNull(light.largeIcon);
+        assertNull(light.extras.getParcelable(Notification.EXTRA_LARGE_ICON));
+    }
 }
diff --git a/packages/ExtServices/tests/src/android/ext/services/notification/SmartActionsHelperTest.java b/packages/ExtServices/tests/src/android/ext/services/notification/SmartActionsHelperTest.java
index 1b0631f..52b7225 100644
--- a/packages/ExtServices/tests/src/android/ext/services/notification/SmartActionsHelperTest.java
+++ b/packages/ExtServices/tests/src/android/ext/services/notification/SmartActionsHelperTest.java
@@ -70,9 +70,11 @@
 
 import javax.annotation.Nullable;
 
+import androidx.test.InstrumentationRegistry;
+import androidx.test.runner.AndroidJUnit4;
+
 @RunWith(AndroidJUnit4.class)
 public class SmartActionsHelperTest {
-    private static final String NOTIFICATION_KEY = "key";
     private static final String RESULT_ID = "id";
     private static final float SCORE = 0.7f;
     private static final CharSequence SMART_REPLY = "Home";
@@ -87,7 +89,6 @@
     IPackageManager mIPackageManager;
     @Mock
     private TextClassifier mTextClassifier;
-    @Mock
     private StatusBarNotification mStatusBarNotification;
     @Mock
     private SmsHelper mSmsHelper;
@@ -107,9 +108,6 @@
         when(mTextClassifier.suggestConversationActions(any(ConversationActions.Request.class)))
                 .thenReturn(new ConversationActions(Arrays.asList(REPLY_ACTION), RESULT_ID));
 
-        when(mStatusBarNotification.getPackageName()).thenReturn("random.app");
-        when(mStatusBarNotification.getUser()).thenReturn(Process.myUserHandle());
-        when(mStatusBarNotification.getKey()).thenReturn(NOTIFICATION_KEY);
         mNotificationBuilder = new Notification.Builder(mContext, "channel");
         mSettings = AssistantSettings.createForTesting(
                 null, null, Process.myUserHandle().getIdentifier(), null);
@@ -118,10 +116,15 @@
         mSmartActionsHelper = new SmartActionsHelper(mContext, mSettings);
     }
 
+    private void setStatusBarNotification(Notification n) {
+        mStatusBarNotification = new StatusBarNotification("random.app", "random.app", 0, "tag",
+                Process.myUid(), Process.myPid(), n, Process.myUserHandle(), null, 0);
+    }
+
     @Test
     public void testSuggest_notMessageNotification() {
         Notification notification = mNotificationBuilder.setContentText(MESSAGE).build();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         mSmartActionsHelper.suggest(createNotificationEntry());
 
@@ -136,7 +139,7 @@
                         .setContentText(MESSAGE)
                         .setCategory(Notification.CATEGORY_MESSAGE)
                         .build();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         ConversationActions.Request request = runSuggestAndCaptureRequest();
 
@@ -153,7 +156,7 @@
         mSettings.mGenerateActions = false;
         mSettings.mGenerateReplies = false;
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         mSmartActionsHelper.suggest(createNotificationEntry());
 
@@ -166,7 +169,7 @@
         mSettings.mGenerateReplies = true;
         mSettings.mGenerateActions = false;
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         ConversationActions.Request request = runSuggestAndCaptureRequest();
 
@@ -183,7 +186,7 @@
         mSettings.mGenerateReplies = false;
         mSettings.mGenerateActions = true;
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         ConversationActions.Request request = runSuggestAndCaptureRequest();
 
@@ -199,7 +202,7 @@
     @Test
     public void testSuggest_nonMessageStyleMessageNotification() {
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         List<ConversationActions.Message> messages =
                 runSuggestAndCaptureRequest().getConversation();
@@ -232,7 +235,7 @@
                         .setStyle(style)
                         .setActions(createReplyAction())
                         .build();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         List<ConversationActions.Message> messages =
                 runSuggestAndCaptureRequest().getConversation();
@@ -287,7 +290,7 @@
                         .setStyle(style)
                         .setActions(createReplyAction())
                         .build();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         mSmartActionsHelper.suggest(createNotificationEntry());
 
@@ -306,7 +309,7 @@
                         .setStyle(style)
                         .setActions(createReplyAction())
                         .build();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         mSmartActionsHelper.suggest(createNotificationEntry());
 
@@ -317,11 +320,11 @@
     @Test
     public void testOnSuggestedReplySent() {
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         mSmartActionsHelper.suggest(createNotificationEntry());
-        mSmartActionsHelper.onSuggestedReplySent(
-                NOTIFICATION_KEY, SMART_REPLY, NotificationAssistantService.SOURCE_FROM_ASSISTANT);
+        mSmartActionsHelper.onSuggestedReplySent(mStatusBarNotification.getKey(), SMART_REPLY,
+                NotificationAssistantService.SOURCE_FROM_ASSISTANT);
 
         ArgumentCaptor<TextClassifierEvent> argumentCaptor =
                 ArgumentCaptor.forClass(TextClassifierEvent.class);
@@ -337,7 +340,7 @@
     @Test
     public void testOnSuggestedReplySent_anotherNotification() {
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         mSmartActionsHelper.suggest(createNotificationEntry());
         mSmartActionsHelper.onSuggestedReplySent(
@@ -352,11 +355,11 @@
         when(mTextClassifier.suggestConversationActions(any(ConversationActions.Request.class)))
                 .thenReturn(new ConversationActions(Collections.singletonList(REPLY_ACTION), null));
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         mSmartActionsHelper.suggest(createNotificationEntry());
-        mSmartActionsHelper.onSuggestedReplySent(
-                NOTIFICATION_KEY, SMART_REPLY, NotificationAssistantService.SOURCE_FROM_ASSISTANT);
+        mSmartActionsHelper.onSuggestedReplySent(mStatusBarNotification.getKey(), SMART_REPLY,
+                NotificationAssistantService.SOURCE_FROM_ASSISTANT);
 
         verify(mTextClassifier, never()).onTextClassifierEvent(any(TextClassifierEvent.class));
     }
@@ -364,10 +367,10 @@
     @Test
     public void testOnNotificationDirectReply() {
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         mSmartActionsHelper.suggest(createNotificationEntry());
-        mSmartActionsHelper.onNotificationDirectReplied(NOTIFICATION_KEY);
+        mSmartActionsHelper.onNotificationDirectReplied(mStatusBarNotification.getKey());
 
         ArgumentCaptor<TextClassifierEvent> argumentCaptor =
                 ArgumentCaptor.forClass(TextClassifierEvent.class);
@@ -380,7 +383,7 @@
     @Test
     public void testOnNotificationExpansionChanged() {
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         mSmartActionsHelper.suggest(createNotificationEntry());
         mSmartActionsHelper.onNotificationExpansionChanged(createNotificationEntry(), true);
@@ -396,7 +399,7 @@
     @Test
     public void testOnNotificationsSeen_notExpanded() {
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         mSmartActionsHelper.suggest(createNotificationEntry());
         mSmartActionsHelper.onNotificationExpansionChanged(createNotificationEntry(), false);
@@ -408,7 +411,7 @@
     @Test
     public void testOnNotifications_expanded() {
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
 
         mSmartActionsHelper.suggest(createNotificationEntry());
         mSmartActionsHelper.onNotificationExpansionChanged(createNotificationEntry(), true);
@@ -437,7 +440,7 @@
                                 Collections.singletonList(conversationAction), null));
 
         Notification notification = createMessageNotification();
-        when(mStatusBarNotification.getNotification()).thenReturn(notification);
+        setStatusBarNotification(notification);
         SmartActionsHelper.SmartSuggestions suggestions =
                 mSmartActionsHelper.suggest(createNotificationEntry());
 
@@ -476,7 +479,8 @@
     private NotificationEntry createNotificationEntry() {
         NotificationChannel channel =
                 new NotificationChannel("id", "name", NotificationManager.IMPORTANCE_DEFAULT);
-        return new NotificationEntry(mIPackageManager, mStatusBarNotification, channel, mSmsHelper);
+        return new NotificationEntry(
+                mContext, mIPackageManager, mStatusBarNotification, channel, mSmsHelper);
     }
 
     private Notification createMessageNotification() {