Export to AOSP. am: 878a6ef0a1
Original change: https://googleplex-android-review.googlesource.com/c/platform/external/libtextclassifier/+/14516934
Change-Id: Iafeebf1b2b30064f7b4e050c0d9ca6832d2da2bb
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index 4ca058d..1f1e958 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -16,10 +16,7 @@
package com.android.textclassifier;
-import android.content.BroadcastReceiver;
import android.content.Context;
-import android.content.Intent;
-import android.content.IntentFilter;
import android.os.CancellationSignal;
import android.service.textclassifier.TextClassifierService;
import android.view.textclassifier.ConversationActions;
@@ -64,7 +61,6 @@
private TextClassifierImpl textClassifier;
private TextClassifierSettings settings;
private ModelFileManager modelFileManager;
- private BroadcastReceiver localeChangedReceiver;
private LruCache<TextClassificationSessionId, TextClassificationContext> sessionIdToContext;
public DefaultTextClassifierService() {
@@ -87,20 +83,14 @@
normPriorityExecutor = injector.createNormPriorityExecutor();
lowPriorityExecutor = injector.createLowPriorityExecutor();
textClassifier = injector.createTextClassifierImpl(settings, modelFileManager);
- localeChangedReceiver = new LocaleChangedReceiver(modelFileManager);
sessionIdToContext = new LruCache<>(settings.getSessionIdToContextCacheSize());
textClassifierApiUsageLogger =
injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor);
-
- injector
- .getContext()
- .registerReceiver(localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED));
}
@Override
public void onDestroy() {
super.onDestroy();
- injector.getContext().unregisterReceiver(localeChangedReceiver);
}
@Override
@@ -284,22 +274,6 @@
return sessionIdToContext.get(sessionId);
}
- /**
- * Receiver listening to locale change event. Ask ModelFileManager to do clean-up upon receiving.
- */
- static class LocaleChangedReceiver extends BroadcastReceiver {
- private final ModelFileManager modelFileManager;
-
- LocaleChangedReceiver(ModelFileManager modelFileManager) {
- this.modelFileManager = modelFileManager;
- }
-
- @Override
- public void onReceive(Context context, Intent intent) {
- modelFileManager.deleteUnusedModelFiles();
- }
- }
-
// Do not call any of these methods, except the constructor, before Service.onCreate is called.
private static class InjectorImpl implements Injector {
// Do not access the context object before Service.onCreate is invoked.
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index 7383bc1..bf326fb 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -60,6 +60,7 @@
import com.android.textclassifier.common.statsd.TextClassifierEventLogger;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.android.textclassifier.ActionsSuggestionsModel;
+import com.google.android.textclassifier.ActionsSuggestionsModel.ActionSuggestions;
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;
import com.google.common.base.Optional;
@@ -387,7 +388,7 @@
ActionsSuggestionsModel.Conversation nativeConversation =
new ActionsSuggestionsModel.Conversation(nativeMessages);
- ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions =
+ ActionSuggestions nativeSuggestions =
actionsImpl.suggestActionsWithIntents(
nativeConversation,
null,
@@ -404,11 +405,11 @@
* non-null component name is in the extras.
*/
private ConversationActions createConversationActionResult(
- ConversationActions.Request request,
- ActionsSuggestionsModel.ActionSuggestion[] nativeSuggestions) {
+ ConversationActions.Request request, ActionSuggestions nativeSuggestions) {
Collection<String> expectedTypes = resolveActionTypesFromRequest(request);
List<ConversationAction> conversationActions = new ArrayList<>();
- for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion : nativeSuggestions) {
+ for (ActionsSuggestionsModel.ActionSuggestion nativeSuggestion :
+ nativeSuggestions.actionSuggestions) {
String actionType = nativeSuggestion.getActionType();
if (!expectedTypes.contains(actionType)) {
continue;
@@ -690,7 +691,7 @@
private static void checkMainThread() {
if (Looper.myLooper() == Looper.getMainLooper()) {
- TcLog.e(TAG, "TextClassifier called on main thread", new Exception());
+ TcLog.e(TAG, "TCS TextClassifier called on main thread", new Exception());
}
}
diff --git a/java/src/com/android/textclassifier/common/ModelFileManager.java b/java/src/com/android/textclassifier/common/ModelFileManager.java
index a77bdd1..406a889 100644
--- a/java/src/com/android/textclassifier/common/ModelFileManager.java
+++ b/java/src/com/android/textclassifier/common/ModelFileManager.java
@@ -374,27 +374,9 @@
}
}
- /**
- * Returns a {@link File} that represents the destination to download a model.
- *
- * <p>Each model file's name is uniquely formatted based on its unique remote manifest URL.
- *
- * <p>{@link ModelDownloadManager} needs to call this to get the right location and file name.
- *
- * @param modelType the type of the model image to download
- * @param manifestUrl the unique remote url of the model manifest
- */
- public File getDownloadTargetFile(@ModelType.ModelTypeDef String modelType, String manifestUrl) {
- // TODO(licha): Consider preserving the folder hierarchy of the URL
- String fileMidName = manifestUrl.replaceAll("[^A-Za-z0-9]", "_");
- if (fileMidName.startsWith("https___")) {
- fileMidName = fileMidName.substring("https___".length());
- }
- if (fileMidName.endsWith("_manifest")) {
- fileMidName = fileMidName.substring(0, fileMidName.length() - "_manifest".length());
- }
- String fileName = String.format("%s.%s.model", modelType, fileMidName);
- return new File(modelDownloaderDir, fileName);
+ /** Returns the directory containing models downloaded by the downloader. */
+ public File getModelDownloaderDir() {
+ return modelDownloaderDir;
}
/**
diff --git a/java/src/com/android/textclassifier/common/TextClassifierServiceExecutors.java b/java/src/com/android/textclassifier/common/TextClassifierServiceExecutors.java
index 51703b7..43164e0 100644
--- a/java/src/com/android/textclassifier/common/TextClassifierServiceExecutors.java
+++ b/java/src/com/android/textclassifier/common/TextClassifierServiceExecutors.java
@@ -56,7 +56,14 @@
return MoreExecutors.listeningDecorator(
Executors.newFixedThreadPool(
corePoolSize,
- new ThreadFactoryBuilder().setNameFormat(nameFormat).setPriority(priority).build()));
+ new ThreadFactoryBuilder()
+ .setNameFormat(nameFormat)
+ .setPriority(priority)
+ // In Android, those uncaught exceptions will crash the whole process if not handled
+ .setUncaughtExceptionHandler(
+ (thread, throwable) ->
+ TcLog.e(TAG, "Exception from executor: " + thread, throwable))
+ .build()));
}
private TextClassifierServiceExecutors() {}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
index 4d5ca4a..40838ac 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
@@ -50,13 +50,9 @@
@RunWith(AndroidJUnit4.class)
public final class ModelFileManagerTest {
private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
- private static final String URL_SUFFIX = "q/711/en.fb";
- private static final String URL_SUFFIX_2 = "q/712/en.fb";
@ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
- @ModelTypeDef private static final String MODEL_TYPE_2 = ModelType.LANG_ID;
-
@Mock private TextClassifierSettings.IDeviceConfig mockDeviceConfig;
@Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
@@ -370,26 +366,6 @@
}
@Test
- public void getDownloadTargetFile_targetFileInCorrectDir() {
- File targetFile = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL_SUFFIX);
- assertThat(targetFile.getAbsolutePath())
- .startsWith(ApplicationProvider.getApplicationContext().getFilesDir().getAbsolutePath());
- }
-
- @Test
- public void getDownloadTargetFile_filePathIsUnique() {
- File targetFileOne = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL_SUFFIX);
- File targetFileTwo = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL_SUFFIX);
- File targetFileThree = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL_SUFFIX_2);
- File targetFileFour = modelFileManager.getDownloadTargetFile(MODEL_TYPE_2, URL_SUFFIX);
-
- assertThat(targetFileOne.getAbsolutePath()).isEqualTo(targetFileTwo.getAbsolutePath());
- assertThat(targetFileOne.getAbsolutePath()).isNotEqualTo(targetFileThree.getAbsolutePath());
- assertThat(targetFileOne.getAbsolutePath()).isNotEqualTo(targetFileFour.getAbsolutePath());
- assertThat(targetFileThree.getAbsolutePath()).isNotEqualTo(targetFileFour.getAbsolutePath());
- }
-
- @Test
public void modelFileEquals() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
diff --git a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
index ad3992d..b5c8ab6 100644
--- a/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
+++ b/jni/com/google/android/textclassifier/ActionsSuggestionsModel.java
@@ -87,7 +87,7 @@
}
/** Suggests actions / replies to the given conversation. */
- public ActionSuggestion[] suggestActions(
+ public ActionSuggestions suggestActions(
Conversation conversation, ActionSuggestionOptions options, AnnotatorModel annotator) {
return nativeSuggestActions(
actionsModelPtr,
@@ -99,7 +99,7 @@
/* generateAndroidIntents= */ false);
}
- public ActionSuggestion[] suggestActionsWithIntents(
+ public ActionSuggestions suggestActionsWithIntents(
Conversation conversation,
ActionSuggestionOptions options,
Object appContext,
@@ -179,6 +179,19 @@
}
}
+ /** Holds the actions suggested for a given conversation and whether it is sensitive. */
+ public static final class ActionSuggestions {
+ /** The suggested actions, sorted by score in descending order. */
+ public final ActionSuggestion[] actionSuggestions;
+ /** Whether the input conversation is considered sensitive. */
+ public final boolean isSensitive;
+
+ public ActionSuggestions(ActionSuggestion[] actionSuggestions, boolean isSensitive) {
+ this.actionSuggestions = actionSuggestions; // Stored as-is; the array is not copied.
+ this.isSensitive = isSensitive;
+ }
+ }
+
/** Action suggestion that contains a response text and the type of the response. */
public static final class ActionSuggestion {
@Nullable private final String responseText;
@@ -360,7 +373,7 @@
private static native String nativeGetNameWithOffset(int fd, long offset, long size);
- private native ActionSuggestion[] nativeSuggestActions(
+ private native ActionSuggestions nativeSuggestActions(
long context,
Conversation conversation,
ActionSuggestionOptions options,
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index 830976e..b1a042c 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -897,13 +897,12 @@
return false;
}
response->sensitivity_score = sensitive_topic_score.data()[0];
- response->output_filtered_sensitivity =
- (response->sensitivity_score >
- preconditions_.max_sensitive_topic_score);
+ response->is_sensitive = (response->sensitivity_score >
+ preconditions_.max_sensitive_topic_score);
}
// Suppress model outputs.
- if (response->output_filtered_sensitivity) {
+ if (response->is_sensitive) {
return true;
}
@@ -985,6 +984,12 @@
std::unique_ptr<tflite::Interpreter>* interpreter) const {
TC3_CHECK_LE(num_messages, conversation.messages.size());
+ if (sensitive_model_ != nullptr &&
+ sensitive_model_->EvalConversation(conversation, num_messages).first) {
+ response->is_sensitive = true;
+ return true;
+ }
+
if (!model_executor_) {
return true;
}
@@ -1357,11 +1362,7 @@
std::vector<const UniLib::RegexPattern*> post_check_rules;
if (preconditions_.suppress_on_low_confidence_input) {
- if ((sensitive_model_ != nullptr &&
- sensitive_model_
- ->EvalConversation(annotated_conversation, num_messages)
- .first) ||
- regex_actions_->IsLowConfidenceInput(annotated_conversation,
+ if (regex_actions_->IsLowConfidenceInput(annotated_conversation,
num_messages, &post_check_rules)) {
response->output_filtered_low_confidence = true;
return true;
@@ -1375,9 +1376,10 @@
return false;
}
+ // SuggestActionsFromModel also detects if the conversation is sensitive,
+ // either by using the old ngram model or the new model.
// Suppress all predictions if the conversation was deemed sensitive.
- if (preconditions_.suppress_on_sensitive_topic &&
- response->output_filtered_sensitivity) {
+ if (preconditions_.suppress_on_sensitive_topic && response->is_sensitive) {
return true;
}
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index 6e66b32..7fe69fc 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -1821,7 +1821,8 @@
/*annotations=*/{},
/*locales=*/"en"}}});
EXPECT_EQ(response.actions.size(), 0);
- EXPECT_TRUE(response.output_filtered_low_confidence);
+ EXPECT_TRUE(response.is_sensitive);
+ EXPECT_FALSE(response.output_filtered_low_confidence);
}
} // namespace
diff --git a/native/actions/actions_jni.cc b/native/actions/actions_jni.cc
index 5981a17..9e15a2e 100644
--- a/native/actions/actions_jni.cc
+++ b/native/actions/actions_jni.cc
@@ -40,7 +40,6 @@
using libtextclassifier3::ActionsSuggestions;
using libtextclassifier3::ActionsSuggestionsResponse;
-using libtextclassifier3::ActionSuggestion;
using libtextclassifier3::ActionSuggestionOptions;
using libtextclassifier3::Annotator;
using libtextclassifier3::Conversation;
@@ -122,21 +121,34 @@
return options;
}
-StatusOr<ScopedLocalRef<jobjectArray>> ActionSuggestionsToJObjectArray(
+StatusOr<ScopedLocalRef<jobject>> ActionSuggestionsToJObject(
JNIEnv* env, const ActionsSuggestionsJniContext* context,
jobject app_context,
const reflection::Schema* annotations_entity_data_schema,
- const std::vector<ActionSuggestion>& action_result,
+ const ActionsSuggestionsResponse& action_response,
const Conversation& conversation, const jstring device_locales,
const bool generate_intents) {
- auto status_or_result_class = JniHelper::FindClass(
+ // Find the class ActionSuggestion.
+ auto status_or_action_class = JniHelper::FindClass(
env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ActionSuggestion");
- if (!status_or_result_class.ok()) {
+ if (!status_or_action_class.ok()) {
TC3_LOG(ERROR) << "Couldn't find ActionSuggestion class.";
+ return status_or_action_class.status();
+ }
+ ScopedLocalRef<jclass> action_class =
+ std::move(status_or_action_class.ValueOrDie());
+
+ // Find the class ActionSuggestions
+ auto status_or_result_class = JniHelper::FindClass(
+ env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$ActionSuggestions");
+ if (!status_or_result_class.ok()) {
+ TC3_LOG(ERROR) << "Couldn't find ActionSuggestions class.";
return status_or_result_class.status();
}
ScopedLocalRef<jclass> result_class =
std::move(status_or_result_class.ValueOrDie());
+
+ // Find the class Slot.
auto status_or_slot_class = JniHelper::FindClass(
env, TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR "$Slot");
if (!status_or_slot_class.ok()) {
@@ -147,9 +159,9 @@
std::move(status_or_slot_class.ValueOrDie());
TC3_ASSIGN_OR_RETURN(
- const jmethodID result_class_constructor,
+ const jmethodID action_class_constructor,
JniHelper::GetMethodID(
- env, result_class.get(), "<init>",
+ env, action_class.get(), "<init>",
"(Ljava/lang/String;Ljava/lang/String;F[L" TC3_PACKAGE_PATH
TC3_NAMED_VARIANT_CLASS_NAME_STR
";[B[L" TC3_PACKAGE_PATH TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR
@@ -157,39 +169,41 @@
TC3_ASSIGN_OR_RETURN(const jmethodID slot_class_constructor,
JniHelper::GetMethodID(env, slot_class.get(), "<init>",
"(Ljava/lang/String;IIIF)V"));
- TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jobjectArray> results,
- JniHelper::NewObjectArray(env, action_result.size(),
- result_class.get(), nullptr));
- for (int i = 0; i < action_result.size(); i++) {
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobjectArray> actions,
+ JniHelper::NewObjectArray(env, action_response.actions.size(),
+ action_class.get(), nullptr));
+ for (int i = 0; i < action_response.actions.size(); i++) {
ScopedLocalRef<jobjectArray> extras;
const reflection::Schema* actions_entity_data_schema =
context->model()->entity_data_schema();
if (actions_entity_data_schema != nullptr &&
- !action_result[i].serialized_entity_data.empty()) {
+ !action_response.actions[i].serialized_entity_data.empty()) {
TC3_ASSIGN_OR_RETURN(
extras, context->template_handler()->EntityDataAsNamedVariantArray(
actions_entity_data_schema,
- action_result[i].serialized_entity_data));
+ action_response.actions[i].serialized_entity_data));
}
ScopedLocalRef<jbyteArray> serialized_entity_data;
- if (!action_result[i].serialized_entity_data.empty()) {
+ if (!action_response.actions[i].serialized_entity_data.empty()) {
TC3_ASSIGN_OR_RETURN(
serialized_entity_data,
JniHelper::NewByteArray(
- env, action_result[i].serialized_entity_data.size()));
+ env, action_response.actions[i].serialized_entity_data.size()));
TC3_RETURN_IF_ERROR(JniHelper::SetByteArrayRegion(
env, serialized_entity_data.get(), 0,
- action_result[i].serialized_entity_data.size(),
+ action_response.actions[i].serialized_entity_data.size(),
reinterpret_cast<const jbyte*>(
- action_result[i].serialized_entity_data.data())));
+ action_response.actions[i].serialized_entity_data.data())));
}
ScopedLocalRef<jobjectArray> remote_action_templates_result;
if (generate_intents) {
std::vector<RemoteActionTemplate> remote_action_templates;
if (context->intent_generator()->GenerateIntents(
- device_locales, action_result[i], conversation, app_context,
+ device_locales, action_response.actions[i], conversation,
+ app_context,
/*annotations_entity_data_schema=*/annotations_entity_data_schema,
/*actions_entity_data_schema=*/actions_entity_data_schema,
&remote_action_templates)) {
@@ -202,19 +216,20 @@
TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> reply,
context->jni_cache()->ConvertToJavaString(
- action_result[i].response_text));
+ action_response.actions[i].response_text));
TC3_ASSIGN_OR_RETURN(
ScopedLocalRef<jstring> action_type,
- JniHelper::NewStringUTF(env, action_result[i].type.c_str()));
+ JniHelper::NewStringUTF(env, action_response.actions[i].type.c_str()));
ScopedLocalRef<jobjectArray> slots;
- if (!action_result[i].slots.empty()) {
- TC3_ASSIGN_OR_RETURN(
- slots, JniHelper::NewObjectArray(env, action_result[i].slots.size(),
- slot_class.get(), nullptr));
- for (int j = 0; j < action_result[i].slots.size(); j++) {
- const Slot& slot_c = action_result[i].slots[j];
+ if (!action_response.actions[i].slots.empty()) {
+ TC3_ASSIGN_OR_RETURN(slots,
+ JniHelper::NewObjectArray(
+ env, action_response.actions[i].slots.size(),
+ slot_class.get(), nullptr));
+ for (int j = 0; j < action_response.actions[i].slots.size(); j++) {
+ const Slot& slot_c = action_response.actions[i].slots[j];
TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> slot_type,
JniHelper::NewStringUTF(env, slot_c.type.c_str()));
@@ -231,16 +246,28 @@
}
TC3_ASSIGN_OR_RETURN(
- ScopedLocalRef<jobject> result,
+ ScopedLocalRef<jobject> action,
JniHelper::NewObject(
- env, result_class.get(), result_class_constructor, reply.get(),
- action_type.get(), static_cast<jfloat>(action_result[i].score),
- extras.get(), serialized_entity_data.get(),
- remote_action_templates_result.get(), slots.get()));
+ env, action_class.get(), action_class_constructor, reply.get(),
+ action_type.get(),
+ static_cast<jfloat>(action_response.actions[i].score), extras.get(),
+ serialized_entity_data.get(), remote_action_templates_result.get(),
+ slots.get()));
TC3_RETURN_IF_ERROR(
- JniHelper::SetObjectArrayElement(env, results.get(), i, result.get()));
+ JniHelper::SetObjectArrayElement(env, actions.get(), i, action.get()));
}
- return results;
+
+ // Create the ActionSuggestions object.
+ TC3_ASSIGN_OR_RETURN(
+ const jmethodID result_class_constructor,
+ JniHelper::GetMethodID(env, result_class.get(), "<init>",
+ "([L" TC3_PACKAGE_PATH TC3_ACTIONS_CLASS_NAME_STR
+ "$ActionSuggestion;Z)V"));
+ TC3_ASSIGN_OR_RETURN(
+ ScopedLocalRef<jobject> result,
+ JniHelper::NewObject(env, result_class.get(), result_class_constructor,
+ actions.get(), action_response.is_sensitive));
+ return result;
}
StatusOr<ConversationMessage> FromJavaConversationMessage(JNIEnv* env,
@@ -387,7 +414,7 @@
} // namespace libtextclassifier3
using libtextclassifier3::ActionsSuggestionsJniContext;
-using libtextclassifier3::ActionSuggestionsToJObjectArray;
+using libtextclassifier3::ActionSuggestionsToJObject;
using libtextclassifier3::FromJavaActionSuggestionOptions;
using libtextclassifier3::FromJavaConversation;
using libtextclassifier3::JByteArrayToString;
@@ -468,7 +495,7 @@
#endif // TC3_UNILIB_JAVAICU
}
-TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
+TC3_JNI_METHOD(jobject, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
(JNIEnv* env, jobject thiz, jlong ptr, jobject jconversation, jobject joptions,
jlong annotatorPtr, jobject app_context, jstring device_locales,
jboolean generate_intents) {
@@ -490,10 +517,10 @@
annotator ? annotator->entity_data_schema() : nullptr;
TC3_ASSIGN_OR_RETURN_NULL(
- ScopedLocalRef<jobjectArray> result,
- ActionSuggestionsToJObjectArray(
- env, context, app_context, anntotations_entity_data_schema,
- response.actions, conversation, device_locales, generate_intents));
+ ScopedLocalRef<jobject> result,
+ ActionSuggestionsToJObject(
+ env, context, app_context, anntotations_entity_data_schema, response,
+ conversation, device_locales, generate_intents));
return result.release();
}
diff --git a/native/actions/actions_jni.h b/native/actions/actions_jni.h
index 5265a9c..2d2d103 100644
--- a/native/actions/actions_jni.h
+++ b/native/actions/actions_jni.h
@@ -45,7 +45,7 @@
nativeInitializeConversationIntentDetection)
(JNIEnv* env, jobject thiz, jlong ptr, jbyteArray jserialized_config);
-TC3_JNI_METHOD(jobjectArray, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
+TC3_JNI_METHOD(jobject, TC3_ACTIONS_CLASS_NAME, nativeSuggestActions)
(JNIEnv* env, jobject thiz, jlong ptr, jobject jconversation, jobject joptions,
jlong annotatorPtr, jobject app_context, jstring device_locales,
jboolean generate_intents);
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
index d900928..d122687 100644
--- a/native/actions/test_data/actions_suggestions_grammar_test.model
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index aa62c0a..2d97bc8 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index 50918e5..567828b 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
index b43e6d7..99f9040 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_emoji.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index 6a71da3..504d8e0 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
index 72f4d9d..33926c2 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_p13n.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
index a6c8118..730f603 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_tf2_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
index 4a120b2..29fe077 100644
--- a/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
+++ b/native/actions/test_data/actions_suggestions_test.sensitive_tflite.model
Binary files differ
diff --git a/native/actions/types.h b/native/actions/types.h
index 8e39d02..c400bb2 100644
--- a/native/actions/types.h
+++ b/native/actions/types.h
@@ -105,8 +105,8 @@
float sensitivity_score = -1.f;
float triggering_score = -1.f;
- // Whether the output was suppressed by the sensitivity threshold.
- bool output_filtered_sensitivity = false;
+ // Whether the input conversation is considered sensitive.
+ bool is_sensitive = false;
// Whether the output was suppressed by the triggering score threshold.
bool output_filtered_min_triggering_score = false;