Do not create a TextClassificationSession if we are not going to use it. am: c94b3450ca am: 19df97df2f am: 9b5c3db0f2

Original change: https://googleplex-android-review.googlesource.com/c/platform/external/libtextclassifier/+/11734845

Change-Id: I5fdc82fcd236f5a5974f1f7a79aa1f8302434638
diff --git a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
index a6aa9ae..0a2cce7 100644
--- a/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
+++ b/notification/src/com/android/textclassifier/notification/SmartSuggestionsHelper.java
@@ -35,9 +35,10 @@
 import android.util.Pair;
 import android.view.textclassifier.ConversationAction;
 import android.view.textclassifier.ConversationActions;
 import android.view.textclassifier.TextClassificationContext;
 import android.view.textclassifier.TextClassificationManager;
 import android.view.textclassifier.TextClassifier;
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
 import java.time.Instant;
@@ -48,6 +49,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
 import javax.annotation.Nullable;

 /**
@@ -76,8 +78,9 @@
   private static final int MAX_RESULT_ID_TO_CACHE = 20;
   private static final ImmutableList<String> HINTS =
       ImmutableList.of(ConversationActions.Request.HINT_FOR_NOTIFICATION);
-  private static final ConversationActions EMPTY_CONVERSATION_ACTIONS =
-      new ConversationActions(ImmutableList.of(), null);
+  private static final SuggestConversationActionsResult EMPTY_SUGGEST_CONVERSATION_ACTIONS_RESULT =
+      new SuggestConversationActionsResult(
+          Optional.empty(), new ConversationActions(ImmutableList.of(), /* id= */ null));

   private final Context context;
   private final TextClassificationManager textClassificationManager;
@@ -119,19 +122,13 @@
     boolean eligibleForActionAdjustment =
         config.shouldGenerateActions() && isEligibleForActionAdjustment(statusBarNotification);

-    TextClassifier textClassifier =
-        textClassificationManager.createTextClassificationSession(textClassificationContext);
-
-    ConversationActions conversationActionsResult =
+    SuggestConversationActionsResult suggestConversationActionsResult =
         suggestConversationActions(
-            textClassifier,
-            statusBarNotification,
-            eligibleForReplyAdjustment,
-            eligibleForActionAdjustment);
+            statusBarNotification, eligibleForReplyAdjustment, eligibleForActionAdjustment);

-    String resultId = conversationActionsResult.getId();
+    String resultId = suggestConversationActionsResult.conversationActions.getId();
     List<ConversationAction> conversationActions =
-        conversationActionsResult.getConversationActions();
+        suggestConversationActionsResult.conversationActions.getConversationActions();

     ArrayList<CharSequence> replies = new ArrayList<>();
     Map<CharSequence, Float> repliesScore = new ArrayMap<>();
@@ -164,23 +161,31 @@
         actions.add(notificationAction);
       }
     }
-    if (TextUtils.isEmpty(resultId)) {
-      textClassifier.destroy();
-    } else {
-      SmartSuggestionsLogSession session =
-          new SmartSuggestionsLogSession(
-              resultId, repliesScore, textClassifier, textClassificationContext);
-      session.onSuggestionsGenerated(conversationActions);

-      // Store the session if we expect more logging from it, destroy it otherwise.
-      if (!conversationActions.isEmpty()
-          && suggestionsMightBeUsedInNotification(
-              statusBarNotification, !actions.isEmpty(), !replies.isEmpty())) {
-        sessionCache.put(statusBarNotification.getKey(), session);
-      } else {
-        session.destroy();
-      }
-    }
+    suggestConversationActionsResult.textClassifier.ifPresent(
+        textClassifier -> {
+          if (TextUtils.isEmpty(resultId)) {
+            // Missing the result id, skip logging.
+            textClassifier.destroy();
+          } else {
+            SmartSuggestionsLogSession session =
+                new SmartSuggestionsLogSession(
+                    resultId,
+                    repliesScore,
+                    textClassifier,
+                    textClassificationContext);
+            session.onSuggestionsGenerated(conversationActions);
+
+            // Store the session if we expect more logging from it, destroy it otherwise.
+            if (!conversationActions.isEmpty()
+                && suggestionsMightBeUsedInNotification(
+                    statusBarNotification, !actions.isEmpty(), !replies.isEmpty())) {
+              sessionCache.put(statusBarNotification.getKey(), session);
+            } else {
+              session.destroy();
+            }
+          }
+        });

     return new SmartSuggestions(replies, actions);
   }
@@ -260,23 +265,20 @@
   }

   /** Adds action adjustments based on the notification contents. */
-  private ConversationActions suggestConversationActions(
-      TextClassifier textClassifier,
-      StatusBarNotification statusBarNotification,
-      boolean includeReplies,
-      boolean includeActions) {
+  private SuggestConversationActionsResult suggestConversationActions(
+      StatusBarNotification statusBarNotification, boolean includeReplies, boolean includeActions) {
     if (!includeReplies && !includeActions) {
-      return EMPTY_CONVERSATION_ACTIONS;
+      return EMPTY_SUGGEST_CONVERSATION_ACTIONS_RESULT;
     }
     ImmutableList<ConversationActions.Message> messages =
         extractMessages(statusBarNotification.getNotification());
     if (messages.isEmpty()) {
-      return EMPTY_CONVERSATION_ACTIONS;
+      return EMPTY_SUGGEST_CONVERSATION_ACTIONS_RESULT;
     }
     // Do not generate smart actions if the last message is from the local user.
     ConversationActions.Message lastMessage = Iterables.getLast(messages);
     if (arePersonsEqual(ConversationActions.Message.PERSON_USER_SELF, lastMessage.getAuthor())) {
-      return EMPTY_CONVERSATION_ACTIONS;
+      return EMPTY_SUGGEST_CONVERSATION_ACTIONS_RESULT;
     }

     TextClassifier.EntityConfig.Builder typeConfigBuilder =
@@ -300,7 +302,15 @@
             .setTypeConfig(typeConfigBuilder.build())
             .build();

-    return textClassifier.suggestConversationActions(request);
+    TextClassifier textClassifier = createTextClassificationSession();
+    try {
+      return new SuggestConversationActionsResult(
+          Optional.of(textClassifier), textClassifier.suggestConversationActions(request));
+    } catch (RuntimeException e) {
+      // Destroy the session instead of leaking it when the suggestion request fails.
+      textClassifier.destroy();
+      throw e;
+    }
   }

   /**
@@ -464,9 +474,37 @@
     return ImmutableList.copyOf(new ArrayList<>(extractMessages));
   }

+  /**
+   * Creates a new text classification session for {@code textClassificationContext}. The caller
+   * is responsible for eventually destroying the returned session.
+   */
+  @VisibleForTesting
+  TextClassifier createTextClassificationSession() {
+    return textClassificationManager.createTextClassificationSession(textClassificationContext);
+  }
+
   private static boolean arePersonsEqual(Person left, Person right) {
     return Objects.equals(left.getKey(), right.getKey())
         && TextUtils.equals(left.getName(), right.getName())
         && Objects.equals(left.getUri(), right.getUri());
   }
+
+  /**
+   * Result object of {@link #suggestConversationActions(StatusBarNotification, boolean, boolean)}.
+   */
+  private static final class SuggestConversationActionsResult {
+    /**
+     * The text classifier session that was used to make the suggestions, or empty if none was
+     * created. The consumer of this result is responsible for destroying the session.
+     */
+    final Optional<TextClassifier> textClassifier;
+    /** The resultant suggestions. */
+    final ConversationActions conversationActions;
+
+    SuggestConversationActionsResult(
+        Optional<TextClassifier> textClassifier, ConversationActions conversationActions) {
+      this.textClassifier = textClassifier;
+      this.conversationActions = conversationActions;
+    }
+  }
 }
diff --git a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
index 1cbfbf2..9d0a720 100644
--- a/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
+++ b/notification/tests/src/com/android/textclassifier/notification/SmartSuggestionsHelperTest.java
@@ -65,7 +65,7 @@
   private final Context context = ApplicationProvider.getApplicationContext();
   private final FakeTextClassifier fakeTextClassifier = new FakeTextClassifier();
   private final TestConfig config = new TestConfig();
-  private SmartSuggestionsHelper smartActions;
+  private TestableSmartSuggestionsHelper smartActions;
   private Notification.Builder notificationBuilder;

   @Before
@@ -73,10 +73,28 @@
     TextClassificationManager textClassificationManager =
         context.getSystemService(TextClassificationManager.class);
     textClassificationManager.setTextClassifier(fakeTextClassifier);
-    smartActions = new SmartSuggestionsHelper(context, config);
+    smartActions = new TestableSmartSuggestionsHelper(context, config);
     notificationBuilder = new Notification.Builder(context, "id");
   }

+  private static final class TestableSmartSuggestionsHelper extends SmartSuggestionsHelper {
+    private int numOfSessionsCreated = 0;
+
+    TestableSmartSuggestionsHelper(Context context, SmartSuggestionsConfig config) {
+      super(context, config);
+    }
+
+    @Override
+    TextClassifier createTextClassificationSession() {
+      numOfSessionsCreated++;
+      return super.createTextClassificationSession();
+    }
+
+    int getNumOfSessionsCreated() {
+      return numOfSessionsCreated;
+    }
+  }
+
   @Test
   public void onNotificationEnqueued_notMessageCategory() {
     Notification notification = notificationBuilder.setContentText(MESSAGE).build();
@@ -87,6 +105,8 @@

     assertThat(smartSuggestions.getReplies()).isEmpty();
     assertThat(smartSuggestions.getActions()).isEmpty();
+    // Verify that no text classification session was created for this notification.
+    assertThat(smartActions.getNumOfSessionsCreated()).isEqualTo(0);
   }

   @Test
@@ -104,6 +124,7 @@

     assertThat(smartSuggestions.getReplies()).isEmpty();
     assertThat(smartSuggestions.getActions()).isEmpty();
+    assertThat(smartActions.getNumOfSessionsCreated()).isEqualTo(0);
   }

   @Test
@@ -120,6 +141,7 @@

     assertThat(smartSuggestions.getReplies()).isEmpty();
     assertAdjustmentWithSmartAction(smartSuggestions);
+    assertThat(smartActions.getNumOfSessionsCreated()).isEqualTo(1);
   }

   @Test
@@ -136,6 +158,7 @@
     List<Message> messages = request.getConversation();
     assertThat(messages).hasSize(1);
     assertThat(messages.get(0).getText().toString()).isEqualTo(MESSAGE);
+    assertThat(smartActions.getNumOfSessionsCreated()).isEqualTo(1);
   }

   @Test
@@ -169,6 +192,7 @@
     assertMessage(messages.get(1), "secondMessage", PERSON_USER_SELF, 2000);
     assertMessage(messages.get(2), "thirdMessage", userA, 3000);
     assertMessage(messages.get(3), "fourthMessage", userB, 4000);
+    assertThat(smartActions.getNumOfSessionsCreated()).isEqualTo(1);
   }

   @Test
@@ -192,6 +216,7 @@

     assertThat(smartSuggestions.getReplies()).isEmpty();
     assertThat(smartSuggestions.getActions()).isEmpty();
+    assertThat(smartActions.getNumOfSessionsCreated()).isEqualTo(0);
   }

   @Test
@@ -212,6 +237,7 @@

     assertThat(smartSuggestions.getReplies()).isEmpty();
     assertThat(smartSuggestions.getActions()).isEmpty();
+    assertThat(smartActions.getNumOfSessionsCreated()).isEqualTo(0);
   }

   @Test