Export libtextclassifier
ExtServices APK now preloads universal model.
We mmap the model files straight from the APK, so they must not be
compressed when they are packaged into the APK.
Added "-0 .model" to aaptflags so that .model files are stored uncompressed.
Bug: 169395238
Test: atest -p external/libtextclassifier
Change-Id: I3d59ec717a14f71be653159ecc7dd3e87bc9a80e
diff --git a/java/Android.bp b/java/Android.bp
index 893b423..981d7c9 100644
--- a/java/Android.bp
+++ b/java/Android.bp
@@ -35,6 +35,9 @@
sdk_version: "system_current",
min_sdk_version: "30",
manifest: "AndroidManifest.xml",
+ aaptflags: [
+ "-0 .model",
+ ],
}
// Similar to TextClassifierServiceLib, but without the AndroidManifest.
@@ -52,6 +55,10 @@
],
sdk_version: "system_current",
min_sdk_version: "30",
+ aaptflags: [
+ "-0 .model",
+ ],
+
}
java_library {
diff --git a/java/assets/textclassifier/actions_suggestions.universal.model b/java/assets/textclassifier/actions_suggestions.universal.model
new file mode 100755
index 0000000..f74fed4
--- /dev/null
+++ b/java/assets/textclassifier/actions_suggestions.universal.model
Binary files differ
diff --git a/java/assets/textclassifier/annotator.universal.model b/java/assets/textclassifier/annotator.universal.model
new file mode 100755
index 0000000..09f1e0b
--- /dev/null
+++ b/java/assets/textclassifier/annotator.universal.model
Binary files differ
diff --git a/java/assets/textclassifier/lang_id.model b/java/assets/textclassifier/lang_id.model
new file mode 100644
index 0000000..e94dada
--- /dev/null
+++ b/java/assets/textclassifier/lang_id.model
Binary files differ
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index 7de1acc..8c33ffe 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -16,7 +16,10 @@
package com.android.textclassifier;
+import android.content.BroadcastReceiver;
import android.content.Context;
+import android.content.Intent;
+import android.content.IntentFilter;
import android.os.CancellationSignal;
import android.service.textclassifier.TextClassifierService;
import android.view.textclassifier.ConversationActions;
@@ -55,6 +58,7 @@
private TextClassifierImpl textClassifier;
private TextClassifierSettings settings;
private ModelFileManager modelFileManager;
+ private BroadcastReceiver localeChangedReceiver;
public DefaultTextClassifierService() {
this.injector = new InjectorImpl(this);
@@ -76,14 +80,20 @@
normPriorityExecutor = injector.createNormPriorityExecutor();
lowPriorityExecutor = injector.createLowPriorityExecutor();
textClassifier = injector.createTextClassifierImpl(settings, modelFileManager);
+ localeChangedReceiver = new LocaleChangedReceiver(modelFileManager);
textClassifierApiUsageLogger =
injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor);
+
+ injector
+ .getContext()
+ .registerReceiver(localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED));
}
@Override
public void onDestroy() {
super.onDestroy();
+ injector.getContext().unregisterReceiver(localeChangedReceiver); // Pairs with the registerReceiver(localeChangedReceiver, ...) call earlier in this class.
}
@Override
@@ -221,6 +231,22 @@
MoreExecutors.directExecutor());
}
+ /**
+ * Receiver for locale-change broadcasts. On every broadcast it receives, it asks the
+ * {@link ModelFileManager} to delete the model files that are no longer needed.
+ *
+ * <p>The enclosing service registers an instance with an {@link IntentFilter} for
+ * {@link Intent#ACTION_LOCALE_CHANGED} and unregisters it in
+ * {@code DefaultTextClassifierService#onDestroy()}.
+ *
+ * <p>NOTE(review): {@code deleteUnusedModelFiles()} is called directly from {@code onReceive},
+ * so the file deletion runs on whichever thread delivers the broadcast; confirm this is
+ * acceptable.
+ */
+ static class LocaleChangedReceiver extends BroadcastReceiver {
+ private final ModelFileManager modelFileManager;
+
+ LocaleChangedReceiver(ModelFileManager modelFileManager) {
+ this.modelFileManager = modelFileManager;
+ }
+
+ @Override
+ public void onReceive(Context context, Intent intent) { // Neither argument is inspected.
+ modelFileManager.deleteUnusedModelFiles();
+ }
+ }
+
// Do not call any of these methods, except the constructor, before Service.onCreate is called.
private static class InjectorImpl implements Injector {
// Do not access the context object before Service.onCreate is invoked.
@@ -231,6 +257,11 @@
}
@Override
+ public Context getContext() {
+ return context;
+ }
+
+ @Override
public ModelFileManager createModelFileManager(TextClassifierSettings settings) {
return new ModelFileManager(context, settings);
}
@@ -280,6 +311,8 @@
* class testable.
*/
interface Injector {
+ Context getContext();
+
ModelFileManager createModelFileManager(TextClassifierSettings settings);
TextClassifierSettings createTextClassifierSettings();
diff --git a/java/src/com/android/textclassifier/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
index 39a70bc..9bc31fb 100644
--- a/java/src/com/android/textclassifier/ModelFileManager.java
+++ b/java/src/com/android/textclassifier/ModelFileManager.java
@@ -17,13 +17,15 @@
package com.android.textclassifier;
import android.content.Context;
+import android.content.res.AssetFileDescriptor;
+import android.content.res.AssetManager;
import android.os.LocaleList;
import android.os.ParcelFileDescriptor;
-import android.text.TextUtils;
+import android.util.ArraySet;
import androidx.annotation.GuardedBy;
import androidx.annotation.StringDef;
-import com.android.textclassifier.ModelFileManager.ModelFile;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import androidx.collection.ArrayMap;
+import com.android.textclassifier.ModelFileManager.ModelType.ModelTypeDef;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
import com.android.textclassifier.utils.IndentingPrintWriter;
@@ -31,64 +33,104 @@
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
-import com.google.common.base.Splitter;
+import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
import java.io.File;
-import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
+import java.util.Map;
import java.util.Objects;
-import java.util.function.Function;
-import java.util.function.Supplier;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
-// TODO(licha): Support garbage collection to delete unused model files
+// TODO(licha): Consider making this a singleton class
+// TODO(licha): Check whether this is thread-safe
/**
* Manages all model files in storage. {@link TextClassifierImpl} depends on this class to get the
* model files to load.
*/
final class ModelFileManager {
+
private static final String TAG = "ModelFileManager";
+
private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models/";
+ private static final File CONFIG_UPDATER_DIR = new File("/data/misc/textclassifier/");
+ private static final String ASSETS_DIR = "textclassifier";
- private final File downloadModelDir;
- private final ImmutableMap<String, Supplier<ImmutableList<ModelFile>>> modelFileSuppliers;
+ private final List<ModelFileLister> modelFileListers;
+ private final File modelDownloaderDir;
- /** Create a ModelFileManager based on hardcoded model file locations. */
public ModelFileManager(Context context, TextClassifierSettings settings) {
Preconditions.checkNotNull(context);
Preconditions.checkNotNull(settings);
- this.downloadModelDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
- if (!downloadModelDir.exists()) {
- downloadModelDir.mkdirs();
- }
- ImmutableMap.Builder<String, Supplier<ImmutableList<ModelFile>>> suppliersBuilder =
- ImmutableMap.builder();
- for (String modelType : ModelType.values()) {
- suppliersBuilder.put(
- modelType, new ModelFileSupplierImpl(settings, modelType, downloadModelDir));
- }
- this.modelFileSuppliers = suppliersBuilder.build();
+ AssetManager assetManager = context.getAssets();
+ this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+ modelFileListers =
+ ImmutableList.of(
+ // Annotator models.
+ new RegularFilePatternMatchLister(
+ ModelType.ANNOTATOR,
+ this.modelDownloaderDir,
+ "annotator\\.(.*)\\.model",
+ settings::isModelDownloadManagerEnabled),
+ new RegularFileFullMatchLister(
+ ModelType.ANNOTATOR,
+ new File(CONFIG_UPDATER_DIR, "textclassifier.model"),
+ /* isEnabled= */ () -> true),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.ANNOTATOR,
+ ASSETS_DIR,
+ "annotator\\.(.*)\\.model",
+ /* isEnabled= */ () -> true),
+ // Actions models.
+ new RegularFilePatternMatchLister(
+ ModelType.ACTIONS_SUGGESTIONS,
+ this.modelDownloaderDir,
+ "actions_suggestions\\.(.*)\\.model",
+ settings::isModelDownloadManagerEnabled),
+ new RegularFileFullMatchLister(
+ ModelType.ACTIONS_SUGGESTIONS,
+ new File(CONFIG_UPDATER_DIR, "actions_suggestions.model"),
+ /* isEnabled= */ () -> true),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.ACTIONS_SUGGESTIONS,
+ ASSETS_DIR,
+ "actions_suggestions\\.(.*)\\.model",
+ /* isEnabled= */ () -> true),
+ // LangID models.
+ new RegularFilePatternMatchLister(
+ ModelType.LANG_ID,
+ this.modelDownloaderDir,
+ "lang_id\\.(.*)\\.model",
+ settings::isModelDownloadManagerEnabled),
+ new RegularFileFullMatchLister(
+ ModelType.LANG_ID,
+ new File(CONFIG_UPDATER_DIR, "lang_id.model"),
+ /* isEnabled= */ () -> true),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.LANG_ID,
+ ASSETS_DIR,
+ "lang_id.model",
+ /* isEnabled= */ () -> true));
}
@VisibleForTesting
- ModelFileManager(
- File downloadModelDir,
- ImmutableMap<String, Supplier<ImmutableList<ModelFile>>> modelFileSuppliers) {
- this.downloadModelDir = Preconditions.checkNotNull(downloadModelDir);
- this.modelFileSuppliers = Preconditions.checkNotNull(modelFileSuppliers);
+ ModelFileManager(Context context, List<ModelFileLister> modelFileListers) {
+ this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+ this.modelFileListers = ImmutableList.copyOf(modelFileListers);
}
/**
@@ -96,27 +138,203 @@
*
* @param modelType which type of model files to look for
*/
- public ImmutableList<ModelFile> listModelFiles(@ModelType.ModelTypeDef String modelType) {
- if (modelFileSuppliers.containsKey(modelType)) {
- return modelFileSuppliers.get(modelType).get();
+ public ImmutableList<ModelFile> listModelFiles(@ModelTypeDef String modelType) {
+ Preconditions.checkNotNull(modelType);
+
+ ImmutableList.Builder<ModelFile> modelFiles = new ImmutableList.Builder<>();
+ for (ModelFileLister modelFileLister : modelFileListers) {
+ modelFiles.addAll(modelFileLister.list(modelType));
}
- return ImmutableList.of();
+ return modelFiles.build();
+ }
+
+ /**
+ * A source of model files. {@code ModelFileManager#listModelFiles} calls {@code list} on every
+ * lister it holds and concatenates the results in lister order.
+ *
+ * <p>The implementations in this file return an empty list, never {@code null}, when the
+ * requested model type is not theirs, when they are disabled, or when nothing is found.
+ */
+ public interface ModelFileLister {
+ List<ModelFile> list(@ModelTypeDef String modelType);
+ }
+
+ /**
+ * Lists at most one model file: the regular file at a fixed path.
+ *
+ * <p>{@code list} returns an empty list when the requested model type is not the one this lister
+ * was created for, when {@code isEnabled} supplies false, when the target file does not exist,
+ * or when {@code ModelFile.createFromRegularFile} throws an {@link IOException} (which is
+ * logged as an error).
+ */
+ public static class RegularFileFullMatchLister implements ModelFileLister {
+ private final String modelType;
+ private final File targetFile;
+ private final Supplier<Boolean> isEnabled;
+
+ /**
+ * @param modelType the type of the model
+ * @param targetFile the expected model file
+ * @param isEnabled supplies whether this lister is enabled; evaluated by {@code list}, not at
+ * construction time
+ */
+ public RegularFileFullMatchLister(
+ @ModelTypeDef String modelType, File targetFile, Supplier<Boolean> isEnabled) {
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.targetFile = Preconditions.checkNotNull(targetFile);
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) { // This lister serves exactly one model type.
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ if (!targetFile.exists()) {
+ return ImmutableList.of();
+ }
+ try {
+ return ImmutableList.of(ModelFile.createFromRegularFile(targetFile, modelType));
+ } catch (IOException e) {
+ TcLog.e(
+ TAG, "Failed to call createFromRegularFile with: " + targetFile.getAbsolutePath(), e);
+ }
+ return ImmutableList.of(); // Reached only after the IOException above was logged.
+ }
+ }
+
+ /**
+ * Lists model files in a specified folder by doing pattern matching on file names.
+ *
+ * <p>{@code list} returns an empty list when the requested model type is not the one this lister
+ * was created for, when {@code isEnabled} supplies false, when {@code folder} is not a
+ * directory, or when {@code folder.listFiles()} returns {@code null}. Otherwise every entry
+ * whose whole name matches the regex and that is a regular file is turned into a
+ * {@link ModelFile}; entries for which {@code ModelFile.createFromRegularFile} throws an
+ * {@link IOException} are skipped after logging a warning.
+ */
+ public static class RegularFilePatternMatchLister implements ModelFileLister {
+ private final String modelType;
+ private final File folder;
+ private final Pattern fileNamePattern;
+ private final Supplier<Boolean> isEnabled;
+
+ /**
+ * @param modelType the type of the model
+ * @param folder the folder to list files
+ * @param fileNameRegex the regex to match the file name in the specified folder; compiled once
+ * here
+ * @param isEnabled supplies whether the lister is enabled; evaluated by {@code list}, not at
+ * construction time
+ */
+ public RegularFilePatternMatchLister(
+ @ModelTypeDef String modelType,
+ File folder,
+ String fileNameRegex,
+ Supplier<Boolean> isEnabled) {
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.folder = Preconditions.checkNotNull(folder);
+ this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) { // This lister serves exactly one model type.
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ if (!folder.isDirectory()) {
+ return ImmutableList.of();
+ }
+ File[] files = folder.listFiles();
+ if (files == null) {
+ return ImmutableList.of();
+ }
+ ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+ for (File file : files) {
+ final Matcher matcher = fileNamePattern.matcher(file.getName());
+ if (!matcher.matches() || !file.isFile()) { // matches() requires the whole name to match.
+ continue;
+ }
+ try {
+ modelFilesBuilder.add(ModelFile.createFromRegularFile(file, modelType));
+ } catch (IOException e) {
+ TcLog.w(TAG, "Failed to call createFromRegularFile with: " + file.getAbsolutePath());
+ }
+ }
+ return modelFilesBuilder.build();
+ }
+ }
+
+ /**
+ * Lists the model files preloaded in the APK file.
+ *
+ * <p>{@code list} returns an empty list when the requested model type is not the one this lister
+ * was created for or when {@code isEnabled} supplies false. Otherwise the asset names under
+ * {@code pathToList} are matched (whole name) against the regex and each match is turned into a
+ * {@link ModelFile}; assets for which {@code ModelFile.createFromAsset} throws an
+ * {@link IOException} are skipped after logging a warning.
+ *
+ * <p>The list built this way is cached and returned by later calls. A listing that fails
+ * ({@code assetManager.list} throws or returns {@code null}) is not cached, so the next call
+ * lists again.
+ */
+ public static class AssetFilePatternMatchLister implements ModelFileLister {
+ private final AssetManager assetManager;
+ private final String modelType;
+ private final String pathToList;
+ private final Pattern fileNamePattern;
+ private final Supplier<Boolean> isEnabled;
+ private final Object lock = new Object();
+ // Assets won't change without updating the app, so cache the result for performance reasons.
+ @GuardedBy("lock")
+ private final Map<String, ImmutableList<ModelFile>> resultCache;
+
+ /**
+ * @param assetManager the asset manager used to list and open the asset files
+ * @param modelType the type of the model.
+ * @param pathToList the folder to list files
+ * @param fileNameRegex the regex to match the file name in the specified folder; compiled once
+ * here
+ * @param isEnabled supplies whether this lister is enabled; evaluated by {@code list}, not at
+ * construction time
+ */
+ public AssetFilePatternMatchLister(
+ AssetManager assetManager,
+ @ModelTypeDef String modelType,
+ String pathToList,
+ String fileNameRegex,
+ Supplier<Boolean> isEnabled) {
+ this.assetManager = Preconditions.checkNotNull(assetManager);
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.pathToList = Preconditions.checkNotNull(pathToList);
+ this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ resultCache = new ArrayMap<>();
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) { // This lister serves exactly one model type.
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ synchronized (lock) {
+ if (resultCache.get(modelType) != null) {
+ return resultCache.get(modelType);
+ }
+ String[] fileNames = null;
+ try {
+ fileNames = assetManager.list(pathToList);
+ } catch (IOException e) {
+ TcLog.e(TAG, "Failed to list assets", e);
+ }
+ if (fileNames == null) { // Listing failed or returned null: return without caching.
+ return ImmutableList.of();
+ }
+ ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+ for (String fileName : fileNames) {
+ final Matcher matcher = fileNamePattern.matcher(fileName);
+ if (!matcher.matches()) {
+ continue;
+ }
+ String absolutePath =
+ new StringBuilder(pathToList).append('/').append(fileName).toString();
+ try {
+ modelFilesBuilder.add(ModelFile.createFromAsset(assetManager, absolutePath, modelType));
+ } catch (IOException e) {
+ TcLog.w(TAG, "Failed to call createFromAsset with: " + absolutePath);
+ }
+ }
+ ImmutableList<ModelFile> result = modelFilesBuilder.build();
+ resultCache.put(modelType, result);
+ return result;
+ }
+ }
+ }
/**
* Returns the best model file for the given localelist, {@code null} if nothing is found.
*
* @param modelType the type of model to look up (e.g. annotator, lang_id, etc.)
- * @param localeList an ordered list of user preferences for locales, use {@code null} if there is
- * no preference.
+ * @param localePreferences an ordered list of user preferences for locales, use {@code null} if
+ * there is no preference.
*/
@Nullable
public ModelFile findBestModelFile(
- @ModelType.ModelTypeDef String modelType, @Nullable LocaleList localeList) {
+ @ModelTypeDef String modelType, @Nullable LocaleList localePreferences) {
final String languages =
- localeList == null || localeList.isEmpty()
+ localePreferences == null || localePreferences.isEmpty()
? LocaleList.getDefault().toLanguageTags()
- : localeList.toLanguageTags();
+ : localePreferences.toLanguageTags();
final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
ModelFile bestModel = null;
@@ -131,6 +349,34 @@
}
/**
+ * Deletes model files that are not the preferred model for any locale in the user's preference.
+ *
+ * <p>For every model type, the best model file is looked up separately for each locale in
+ * {@link LocaleList#getDefault()}; every other listed model file of that type is deleted if
+ * {@code canWrite()} reports it as writable.
+ *
+ * <p>This method will be invoked as a clean-up after we download a new model successfully, and
+ * it is also called by {@code DefaultTextClassifierService.LocaleChangedReceiver}. Race
+ * conditions are hard to avoid because we do not hold locks for files. But it should rarely cause
+ * any issues since it's safe to delete a model file in use (b/c we mmap it to memory).
+ */
+ public void deleteUnusedModelFiles() {
+ TcLog.d(TAG, "Start to delete unused model files.");
+ LocaleList localeList = LocaleList.getDefault();
+ for (@ModelTypeDef String modelType : ModelType.values()) {
+ ArraySet<ModelFile> allModelFiles = new ArraySet<>(listModelFiles(modelType));
+ for (int i = 0; i < localeList.size(); i++) {
+ // If a model file is preferred for any locale in the locale list, then keep it.
+ ModelFile bestModel = findBestModelFile(modelType, new LocaleList(localeList.get(i)));
+ allModelFiles.remove(bestModel); // findBestModelFile is @Nullable; a null should remove nothing.
+ }
+ for (ModelFile modelFile : allModelFiles) {
+ if (modelFile.canWrite()) { // Always false for asset-backed model files.
+ TcLog.d(TAG, "Deleting model: " + modelFile);
+ if (!modelFile.delete()) {
+ TcLog.w(TAG, "Failed to delete model: " + modelFile);
+ }
+ }
+ }
+ }
+ }
+
+ /**
* Returns a {@link File} that represents the destination to download a model.
*
* <p>Each model file's name is uniquely formatted based on its unique remote manifest URL.
@@ -140,9 +386,9 @@
* @param modelType the type of the model image to download
* @param manifestUrl the unique remote url of the model manifest
*/
- public File getDownloadTargetFile(@ModelType.ModelTypeDef String modelType, String manifestUrl) {
+ public File getDownloadTargetFile(@ModelTypeDef String modelType, String manifestUrl) {
String fileName = String.format("%s.%d.model", modelType, manifestUrl.hashCode());
- return new File(downloadModelDir, fileName);
+ return new File(modelDownloaderDir, fileName);
}
/**
@@ -153,7 +399,7 @@
public void dump(IndentingPrintWriter printWriter) {
printWriter.println("ModelFileManager:");
printWriter.increaseIndent();
- for (@ModelType.ModelTypeDef String modelType : ModelType.values()) {
+ for (@ModelTypeDef String modelType : ModelType.values()) {
printWriter.println(modelType + " model file(s):");
printWriter.increaseIndent();
for (ModelFile modelFile : listModelFiles(modelType)) {
@@ -164,260 +410,102 @@
printWriter.decreaseIndent();
}
- /** Default implementation of the model file supplier. */
- @VisibleForTesting
- static final class ModelFileSupplierImpl implements Supplier<ImmutableList<ModelFile>> {
- private static final String FACTORY_MODEL_DIR = "/etc/textclassifier/";
+ /** Fetch metadata of a model file. */
+ private static class ModelInfoFetcher {
+ private final Function<AssetFileDescriptor, Integer> versionFetcher;
+ private final Function<AssetFileDescriptor, String> supportedLocalesFetcher;
- private static final class ModelFileInfo {
- private final String modelNameRegex;
- private final String configUpdaterModelPath;
- private final Function<Integer, Integer> versionSupplier;
- private final Function<Integer, String> supportedLocalesSupplier;
-
- public ModelFileInfo(
- String modelNameRegex,
- String configUpdaterModelPath,
- Function<Integer, Integer> versionSupplier,
- Function<Integer, String> supportedLocalesSupplier) {
- this.modelNameRegex = Preconditions.checkNotNull(modelNameRegex);
- this.configUpdaterModelPath = Preconditions.checkNotNull(configUpdaterModelPath);
- this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
- this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
- }
-
- public String getModelNameRegex() {
- return modelNameRegex;
- }
-
- public String getConfigUpdaterModelPath() {
- return configUpdaterModelPath;
- }
-
- public Function<Integer, Integer> getVersionSupplier() {
- return versionSupplier;
- }
-
- public Function<Integer, String> getSupportedLocalesSupplier() {
- return supportedLocalesSupplier;
- }
+ private ModelInfoFetcher(
+ Function<AssetFileDescriptor, Integer> versionFetcher,
+ Function<AssetFileDescriptor, String> supportedLocalesFetcher) {
+ this.versionFetcher = versionFetcher;
+ this.supportedLocalesFetcher = supportedLocalesFetcher;
}
- private static final ImmutableMap<String, ModelFileInfo> MODEL_FILE_INFO_MAP =
- ImmutableMap.<String, ModelFileInfo>builder()
- .put(
- ModelType.ANNOTATOR,
- new ModelFileInfo(
- "(annotator|textclassifier)\\.(.*)\\.model",
- "/data/misc/textclassifier/textclassifier.model",
- AnnotatorModel::getVersion,
- AnnotatorModel::getLocales))
- .put(
- ModelType.LANG_ID,
- new ModelFileInfo(
- "lang_id.model",
- "/data/misc/textclassifier/lang_id.model",
- LangIdModel::getVersion,
- fd -> ModelFile.LANGUAGE_INDEPENDENT))
- .put(
- ModelType.ACTIONS_SUGGESTIONS,
- new ModelFileInfo(
- "actions_suggestions\\.(.*)\\.model",
- "/data/misc/textclassifier/actions_suggestions.model",
- ActionsSuggestionsModel::getVersion,
- ActionsSuggestionsModel::getLocales))
- .build();
-
- private final TextClassifierSettings settings;
- @ModelType.ModelTypeDef private final String modelType;
- private final File configUpdaterModelFile;
- private final File downloaderModelDir;
- private final File factoryModelDir;
- private final Pattern modelFilenamePattern;
- private final Function<Integer, Integer> versionSupplier;
- private final Function<Integer, String> supportedLocalesSupplier;
- private final Object lock = new Object();
-
- @GuardedBy("lock")
- private ImmutableList<ModelFile> factoryModels;
-
- public ModelFileSupplierImpl(
- TextClassifierSettings settings,
- @ModelType.ModelTypeDef String modelType,
- File downloaderModelDir) {
- this(
- settings,
- modelType,
- new File(FACTORY_MODEL_DIR),
- MODEL_FILE_INFO_MAP.get(modelType).getModelNameRegex(),
- new File(MODEL_FILE_INFO_MAP.get(modelType).getConfigUpdaterModelPath()),
- downloaderModelDir,
- MODEL_FILE_INFO_MAP.get(modelType).getVersionSupplier(),
- MODEL_FILE_INFO_MAP.get(modelType).getSupportedLocalesSupplier());
+ int getVersion(AssetFileDescriptor assetFileDescriptor) {
+ return versionFetcher.apply(assetFileDescriptor);
}
- @VisibleForTesting
- ModelFileSupplierImpl(
- TextClassifierSettings settings,
- @ModelType.ModelTypeDef String modelType,
- File factoryModelDir,
- String modelFileNameRegex,
- File configUpdaterModelFile,
- File downloaderModelDir,
- Function<Integer, Integer> versionSupplier,
- Function<Integer, String> supportedLocalesSupplier) {
- this.settings = Preconditions.checkNotNull(settings);
- this.modelType = Preconditions.checkNotNull(modelType);
- this.factoryModelDir = Preconditions.checkNotNull(factoryModelDir);
- this.modelFilenamePattern = Pattern.compile(Preconditions.checkNotNull(modelFileNameRegex));
- this.configUpdaterModelFile = Preconditions.checkNotNull(configUpdaterModelFile);
- this.downloaderModelDir = Preconditions.checkNotNull(downloaderModelDir);
- this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
- this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
+ String getSupportedLocales(AssetFileDescriptor assetFileDescriptor) {
+ return supportedLocalesFetcher.apply(assetFileDescriptor);
}
- @Override
- public ImmutableList<ModelFile> get() {
- final List<ModelFile> modelFiles = new ArrayList<>();
- // The dwonloader and config updater model have higher precedences.
- if (downloaderModelDir.exists() && settings.isModelDownloadManagerEnabled()) {
- modelFiles.addAll(getMatchedModelFiles(downloaderModelDir));
+ static ModelInfoFetcher create(@ModelTypeDef String modelType) {
+ switch (modelType) {
+ case ModelType.ANNOTATOR:
+ return new ModelInfoFetcher(AnnotatorModel::getVersion, AnnotatorModel::getLocales);
+ case ModelType.ACTIONS_SUGGESTIONS:
+ return new ModelInfoFetcher(
+ ActionsSuggestionsModel::getVersion, ActionsSuggestionsModel::getLocales);
+ case ModelType.LANG_ID:
+ return new ModelInfoFetcher(
+ LangIdModel::getVersion, afd -> ModelFile.LANGUAGE_INDEPENDENT);
+ default: // fall out
}
- if (configUpdaterModelFile.exists()) {
- final ModelFile updatedModel = createModelFile(configUpdaterModelFile);
- if (updatedModel != null) {
- modelFiles.add(updatedModel);
- }
- }
- // Factory models should never have overlapping locales, so the order doesn't matter.
- synchronized (lock) {
- if (factoryModels == null) {
- factoryModels = getMatchedModelFiles(factoryModelDir);
- }
- modelFiles.addAll(factoryModels);
- }
- return ImmutableList.copyOf(modelFiles);
- }
-
- private ImmutableList<ModelFile> getMatchedModelFiles(File parentDir) {
- ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
- if (parentDir.exists() && parentDir.isDirectory()) {
- final File[] files = parentDir.listFiles();
- for (File file : files) {
- final Matcher matcher = modelFilenamePattern.matcher(file.getName());
- if (matcher.matches() && file.isFile()) {
- final ModelFile model = createModelFile(file);
- if (model != null) {
- modelFilesBuilder.add(model);
- }
- }
- }
- }
- return modelFilesBuilder.build();
- }
-
- /** Returns null if the path did not point to a compatible model. */
- @Nullable
- private ModelFile createModelFile(File file) {
- if (!file.exists()) {
- return null;
- }
- ParcelFileDescriptor modelFd = null;
- try {
- modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
- if (modelFd == null) {
- return null;
- }
- final int modelFdInt = modelFd.getFd();
- final int version = versionSupplier.apply(modelFdInt);
- final String supportedLocalesStr = supportedLocalesSupplier.apply(modelFdInt);
- if (supportedLocalesStr.isEmpty()) {
- TcLog.d(TAG, "Ignoring " + file.getAbsolutePath());
- return null;
- }
- final List<Locale> supportedLocales = new ArrayList<>();
- for (String langTag : Splitter.on(',').split(supportedLocalesStr)) {
- supportedLocales.add(Locale.forLanguageTag(langTag));
- }
- return new ModelFile(
- modelType,
- file,
- version,
- supportedLocales,
- supportedLocalesStr,
- ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
- } catch (FileNotFoundException e) {
- TcLog.e(TAG, "Failed to find " + file.getAbsolutePath(), e);
- return null;
- } finally {
- maybeCloseAndLogError(modelFd);
- }
- }
-
- /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
- private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
- if (fd == null) {
- return;
- }
- try {
- fd.close();
- } catch (IOException e) {
- TcLog.e(TAG, "Error closing file.", e);
- }
+ throw new IllegalStateException("Unsupported model types");
}
}
/** Describes TextClassifier model files on disk. */
- public static final class ModelFile {
- public static final String LANGUAGE_INDEPENDENT = "*";
+ public static class ModelFile {
+ @VisibleForTesting static final String LANGUAGE_INDEPENDENT = "*";
- @ModelType.ModelTypeDef private final String modelType;
- private final File file;
- private final int version;
- private final List<Locale> supportedLocales;
- private final String supportedLocalesStr;
- private final boolean languageIndependent;
+ @ModelTypeDef public final String modelType;
+ public final String absolutePath;
+ public final int version;
+ public final LocaleList supportedLocales;
+ public final boolean languageIndependent;
+ public final boolean isAsset;
- public ModelFile(
- @ModelType.ModelTypeDef String modelType,
- File file,
+ public static ModelFile createFromRegularFile(File file, @ModelTypeDef String modelType)
+ throws IOException { // Builds a non-asset ModelFile for a file on disk.
+ ParcelFileDescriptor pfd =
+ ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ try (AssetFileDescriptor afd = new AssetFileDescriptor(pfd, 0, file.length())) { // Whole file: offset 0, full length.
+ return createFromAssetFileDescriptor(
+ file.getAbsolutePath(), modelType, afd, /* isAsset= */ false);
+ }
+ }
+
+ public static ModelFile createFromAsset(
+ AssetManager assetManager, String absolutePath, @ModelTypeDef String modelType)
+ throws IOException { // Builds an asset-backed ModelFile; absolutePath is the path passed to openFd.
+ try (AssetFileDescriptor assetFileDescriptor = assetManager.openFd(absolutePath)) { // Closed when this block exits.
+ return createFromAssetFileDescriptor(
+ absolutePath, modelType, assetFileDescriptor, /* isAsset= */ true);
+ }
+ }
+
+ private static ModelFile createFromAssetFileDescriptor(
+ String absolutePath,
+ @ModelTypeDef String modelType,
+ AssetFileDescriptor assetFileDescriptor,
+ boolean isAsset) { // Shared by both factories; does not close assetFileDescriptor.
+ ModelInfoFetcher modelInfoFetcher = ModelInfoFetcher.create(modelType); // Throws IllegalStateException for an unknown type.
+ return new ModelFile(
+ modelType,
+ absolutePath,
+ modelInfoFetcher.getVersion(assetFileDescriptor),
+ modelInfoFetcher.getSupportedLocales(assetFileDescriptor),
+ isAsset);
+ }
+
+ @VisibleForTesting
+ ModelFile(
+ @ModelTypeDef String modelType,
+ String absolutePath,
int version,
- List<Locale> supportedLocales,
- String supportedLocalesStr,
- boolean languageIndependent) {
- this.modelType = Preconditions.checkNotNull(modelType);
- this.file = Preconditions.checkNotNull(file);
+ String supportedLocaleTags,
+ boolean isAsset) {
+ this.modelType = modelType;
+ this.absolutePath = absolutePath;
this.version = version;
- this.supportedLocales = Preconditions.checkNotNull(supportedLocales);
- this.supportedLocalesStr = Preconditions.checkNotNull(supportedLocalesStr);
- this.languageIndependent = languageIndependent;
- }
-
- /** Returns the type of this model, defined in {@link ModelType}. */
- @ModelType.ModelTypeDef
- public String getModelType() {
- return modelType;
- }
-
- /** Returns the absolute path to the model file. */
- public String getPath() {
- return file.getAbsolutePath();
- }
-
- /** Returns a name to use for id generation, effectively the name of the model file. */
- public String getName() {
- return file.getName();
- }
-
- /** Returns the version tag in the model's metadata. */
- public int getVersion() {
- return version;
- }
-
- /** Returns whether the language supports any language in the given ranges. */
- public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
- Preconditions.checkNotNull(languageRanges);
- return languageIndependent || Locale.lookup(languageRanges, supportedLocales) != null;
+ this.languageIndependent = LANGUAGE_INDEPENDENT.equals(supportedLocaleTags);
+ this.supportedLocales =
+ languageIndependent
+ ? LocaleList.getEmptyLocaleList()
+ : LocaleList.forLanguageTags(supportedLocaleTags);
+ this.isAsset = isAsset;
}
/** Returns if this model file is preferred to the given one. */
@@ -437,70 +525,111 @@
}
// A higher-version model is preferred.
- if (version > model.getVersion()) {
+ if (version > model.version) {
return true;
}
return false;
}
+ /** Returns whether the model supports any language in the given ranges. */
+ public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
+ Preconditions.checkNotNull(languageRanges);
+ if (languageIndependent) {
+ return true;
+ }
+ List<String> supportedLocaleTags =
+ Arrays.asList(supportedLocales.toLanguageTags().split(","));
+ return Locale.lookupTag(languageRanges, supportedLocaleTags) != null;
+ }
+
+ public AssetFileDescriptor open(AssetManager assetManager) throws IOException {
+ if (isAsset) {
+ return assetManager.openFd(absolutePath);
+ }
+ File file = new File(absolutePath);
+ ParcelFileDescriptor parcelFileDescriptor =
+ ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ return new AssetFileDescriptor(parcelFileDescriptor, 0, file.length());
+ }
+
+ public boolean canWrite() {
+ if (isAsset) {
+ return false;
+ }
+ return new File(absolutePath).canWrite();
+ }
+
+ public boolean delete() {
+ if (isAsset) {
+ throw new IllegalStateException("asset is read-only, deleting it is not allowed.");
+ }
+ return new File(absolutePath).delete();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof ModelFile)) {
+ return false;
+ }
+ ModelFile modelFile = (ModelFile) o;
+ return version == modelFile.version
+ && languageIndependent == modelFile.languageIndependent
+ && isAsset == modelFile.isAsset
+ && Objects.equals(modelType, modelFile.modelType)
+ && Objects.equals(absolutePath, modelFile.absolutePath)
+ && Objects.equals(supportedLocales, modelFile.supportedLocales);
+ }
+
@Override
public int hashCode() {
- return Objects.hash(getPath());
- }
-
- @Override
- public boolean equals(Object other) {
- if (this == other) {
- return true;
- }
- if (other instanceof ModelFile) {
- final ModelFile otherModel = (ModelFile) other;
- return TextUtils.equals(getPath(), otherModel.getPath());
- }
- return false;
+ return Objects.hash(
+ modelType, absolutePath, version, supportedLocales, languageIndependent, isAsset);
}
public ModelInfo toModelInfo() {
- return new ModelInfo(getVersion(), supportedLocalesStr);
+ return new ModelInfo(version, languageIndependent ? LANGUAGE_INDEPENDENT : supportedLocales.toLanguageTags());
}
@Override
public String toString() {
return String.format(
Locale.US,
- "ModelFile { type=%s path=%s name=%s version=%d locales=%s }",
+ "ModelFile { type=%s path=%s version=%d locales=%s isAsset=%b}",
modelType,
- getPath(),
- getName(),
+ absolutePath,
version,
- supportedLocalesStr);
+ languageIndependent ? LANGUAGE_INDEPENDENT : supportedLocales.toLanguageTags(),
+ isAsset);
}
public static ImmutableList<Optional<ModelInfo>> toModelInfos(
- Optional<ModelFile>... modelFiles) {
+ Optional<ModelFileManager.ModelFile>... modelFiles) {
return Arrays.stream(modelFiles)
- .map(modelFile -> modelFile.transform(ModelFile::toModelInfo))
+ .map(modelFile -> modelFile.transform(ModelFileManager.ModelFile::toModelInfo))
.collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf));
}
+ }
- /** Effectively an enum class to represent types of models. */
- public static final class ModelType {
- @Retention(RetentionPolicy.SOURCE)
- @StringDef({ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS})
- public @interface ModelTypeDef {}
+ /** Effectively an enum class to represent types of models. */
+ public static final class ModelType {
+ @Retention(RetentionPolicy.SOURCE)
+ @StringDef({ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS})
+ @interface ModelTypeDef {}
- public static final String ANNOTATOR = "annotator";
- public static final String LANG_ID = "lang_id";
- public static final String ACTIONS_SUGGESTIONS = "actions_suggestions";
+ public static final String ANNOTATOR = "annotator";
+ public static final String LANG_ID = "lang_id";
+ public static final String ACTIONS_SUGGESTIONS = "actions_suggestions";
- public static final ImmutableList<String> VALUES =
- ImmutableList.of(ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS);
+ public static final ImmutableList<String> VALUES =
+ ImmutableList.of(ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS);
- public static ImmutableList<String> values() {
- return VALUES;
- }
-
- private ModelType() {}
+ public static ImmutableList<String> values() {
+ return VALUES;
}
+
+ private ModelType() {}
}
}
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index dbe66f6..b824ed0 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -22,6 +22,7 @@
import android.app.RemoteAction;
import android.content.Context;
import android.content.Intent;
+import android.content.res.AssetFileDescriptor;
import android.icu.util.ULocale;
import android.os.Bundle;
import android.os.LocaleList;
@@ -42,7 +43,7 @@
import androidx.annotation.WorkerThread;
import androidx.core.util.Pair;
import com.android.textclassifier.ModelFileManager.ModelFile;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.intent.LabeledIntent;
import com.android.textclassifier.common.intent.TemplateIntentFactory;
@@ -61,6 +62,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
+import java.io.IOException;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.ArrayList;
@@ -123,7 +125,7 @@
}
@WorkerThread
- TextSelection suggestSelection(TextSelection.Request request) {
+ TextSelection suggestSelection(TextSelection.Request request) throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
final int rangeLength = request.getEndIndex() - request.getStartIndex();
@@ -182,7 +184,7 @@
}
@WorkerThread
- TextClassification classifyText(TextClassification.Request request) {
+ TextClassification classifyText(TextClassification.Request request) throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
LangIdModel langId = getLangIdImpl();
@@ -222,7 +224,7 @@
}
@WorkerThread
- TextLinks generateLinks(TextLinks.Request request) {
+ TextLinks generateLinks(TextLinks.Request request) throws IOException {
Preconditions.checkNotNull(request);
Preconditions.checkArgument(
request.getText().length() <= getMaxGenerateLinksTextLength(),
@@ -332,7 +334,7 @@
TextClassifierEventConverter.fromPlatform(event));
}
- TextLanguage detectLanguage(TextLanguage.Request request) {
+ TextLanguage detectLanguage(TextLanguage.Request request) throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
final TextLanguage.Builder builder = new TextLanguage.Builder();
@@ -345,7 +347,8 @@
return builder.build();
}
- ConversationActions suggestConversationActions(ConversationActions.Request request) {
+ ConversationActions suggestConversationActions(ConversationActions.Request request)
+ throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
ActionsSuggestionsModel actionsImpl = getActionsImpl();
@@ -431,7 +434,7 @@
return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
}
- private AnnotatorModel getAnnotatorImpl(LocaleList localeList) {
+ private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws IOException {
synchronized (lock) {
localeList = localeList == null ? LocaleList.getDefault() : localeList;
final ModelFileManager.ModelFile bestModel =
@@ -444,31 +447,35 @@
// The current annotator model may be still used by another thread / model.
// Do not call close() here, and let the GC to clean it up when no one else
// is using it.
- annotatorImpl = new AnnotatorModel(bestModel.getPath());
- annotatorImpl.setLangIdModel(getLangIdImpl());
- annotatorModelInUse = bestModel;
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ annotatorImpl = new AnnotatorModel(afd);
+ annotatorImpl.setLangIdModel(getLangIdImpl());
+ annotatorModelInUse = bestModel;
+ }
}
return annotatorImpl;
}
}
- private LangIdModel getLangIdImpl() {
+ private LangIdModel getLangIdImpl() throws IOException {
synchronized (lock) {
final ModelFileManager.ModelFile bestModel =
- modelFileManager.findBestModelFile(ModelType.LANG_ID, /* localeList= */ null);
+ modelFileManager.findBestModelFile(ModelType.LANG_ID, /* localePreferences= */ null);
if (bestModel == null) {
throw new IllegalStateException("Failed to find the best LangID model.");
}
if (langIdImpl == null || !Objects.equals(langIdModelInUse, bestModel)) {
TcLog.d(TAG, "Loading " + bestModel);
- langIdImpl = new LangIdModel(bestModel.getPath());
- langIdModelInUse = bestModel;
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ langIdImpl = new LangIdModel(afd);
+ langIdModelInUse = bestModel;
+ }
}
return langIdImpl;
}
}
- private ActionsSuggestionsModel getActionsImpl() {
+ private ActionsSuggestionsModel getActionsImpl() throws IOException {
synchronized (lock) {
// TODO: Use LangID to determine the locale we should use here?
final ModelFileManager.ModelFile bestModel =
@@ -479,8 +486,10 @@
}
if (actionsImpl == null || !Objects.equals(actionModelInUse, bestModel)) {
TcLog.d(TAG, "Loading " + bestModel);
- actionsImpl = new ActionsSuggestionsModel(bestModel.getPath());
- actionModelInUse = bestModel;
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ actionsImpl = new ActionsSuggestionsModel(afd);
+ actionModelInUse = bestModel;
+ }
}
return actionsImpl;
}
diff --git a/java/src/com/android/textclassifier/TextClassifierSettings.java b/java/src/com/android/textclassifier/TextClassifierSettings.java
index c13ff68..b13d166 100644
--- a/java/src/com/android/textclassifier/TextClassifierSettings.java
+++ b/java/src/com/android/textclassifier/TextClassifierSettings.java
@@ -20,7 +20,8 @@
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.TextClassifier;
import androidx.annotation.NonNull;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType.ModelTypeDef;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Splitter;
@@ -44,6 +45,7 @@
* @see android.provider.DeviceConfig#NAMESPACE_TEXTCLASSIFIER
*/
public final class TextClassifierSettings {
+ private static final String TAG = "TextClassifierSettings";
public static final String NAMESPACE = DeviceConfig.NAMESPACE_TEXTCLASSIFIER;
private static final String DELIMITER = ":";
@@ -109,6 +111,9 @@
/** Whether to enable model downloading with ModelDownloadManager */
@VisibleForTesting
static final String MODEL_DOWNLOAD_MANAGER_ENABLED = "model_download_manager_enabled";
+ /** Type of network to download model manifest. A String value of androidx.work.NetworkType. */
+ private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE =
+ "manifest_download_required_network_type";
/** Max attempts allowed for a single ModelDownloader downloading task. */
@VisibleForTesting
static final String MODEL_DOWNLOAD_MAX_ATTEMPTS = "model_download_max_attempts";
@@ -196,6 +201,8 @@
private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
private static final boolean MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT = false;
+ // Manifest files are usually small, so default to any non-roaming network type
+ private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT = "NOT_ROAMING";
private static final int MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT = 5;
private static final String ANNOTATOR_URL_PREFIX_DEFAULT =
"https://www.gstatic.com/android/text_classifier/";
@@ -360,7 +367,7 @@
NAMESPACE, MODEL_DOWNLOAD_MAX_ATTEMPTS, MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT);
}
- public String getModelURLPrefix(@ModelType.ModelTypeDef String modelType) {
+ public String getModelURLPrefix(@ModelTypeDef String modelType) {
switch (modelType) {
case ModelType.ANNOTATOR:
return deviceConfig.getString(
@@ -375,7 +382,7 @@
}
}
- public String getPrimaryModelURLSuffix(@ModelType.ModelTypeDef String modelType) {
+ public String getPrimaryModelURLSuffix(@ModelTypeDef String modelType) {
switch (modelType) {
case ModelType.ANNOTATOR:
return deviceConfig.getString(
diff --git a/java/src/com/android/textclassifier/common/base/TcLog.java b/java/src/com/android/textclassifier/common/base/TcLog.java
index 87f1187..05a2443 100644
--- a/java/src/com/android/textclassifier/common/base/TcLog.java
+++ b/java/src/com/android/textclassifier/common/base/TcLog.java
@@ -16,6 +16,8 @@
package com.android.textclassifier.common.base;
+import android.util.Log;
+
/**
* Logging for android.view.textclassifier package.
*
@@ -31,27 +33,30 @@
public static final String TAG = "androidtc";
/** true: Enables full logging. false: Limits logging to debug level. */
- public static final boolean ENABLE_FULL_LOGGING =
- android.util.Log.isLoggable(TAG, android.util.Log.VERBOSE);
+ public static final boolean ENABLE_FULL_LOGGING = Log.isLoggable(TAG, Log.VERBOSE);
private TcLog() {}
public static void v(String tag, String msg) {
if (ENABLE_FULL_LOGGING) {
- android.util.Log.v(getTag(tag), msg);
+ Log.v(getTag(tag), msg);
}
}
public static void d(String tag, String msg) {
- android.util.Log.d(getTag(tag), msg);
+ Log.d(getTag(tag), msg);
}
public static void w(String tag, String msg) {
- android.util.Log.w(getTag(tag), msg);
+ Log.w(getTag(tag), msg);
+ }
+
+ public static void e(String tag, String msg) {
+ Log.e(getTag(tag), msg);
}
public static void e(String tag, String msg, Throwable tr) {
- android.util.Log.e(getTag(tag), msg, tr);
+ Log.e(getTag(tag), msg, tr);
}
private static String getTag(String customTag) {
diff --git a/java/tests/instrumentation/Android.bp b/java/tests/instrumentation/Android.bp
index a0cd0ec..fa31894 100644
--- a/java/tests/instrumentation/Android.bp
+++ b/java/tests/instrumentation/Android.bp
@@ -55,5 +55,5 @@
instrumentation_for: "TextClassifierService",
- data: ["testdata/*"]
+ data: ["testdata/*"],
}
\ No newline at end of file
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
index 09bd363..20fa508 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -43,7 +43,6 @@
import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import java.util.List;
@@ -190,7 +189,7 @@
@Test
public void missingModelFile_onFailureShouldBeCalled() throws Exception {
testInjector.setModelFileManager(
- new ModelFileManager(TestDataUtils.getTestDataFolder(), ImmutableMap.of()));
+ new ModelFileManager(ApplicationProvider.getApplicationContext(), ImmutableList.of()));
defaultTextClassifierService.onCreate();
TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build();
@@ -231,9 +230,14 @@
}
@Override
+ public Context getContext() {
+ return context;
+ }
+
+ @Override
public ModelFileManager createModelFileManager(TextClassifierSettings settings) {
if (modelFileManager == null) {
- return TestDataUtils.createModelFileManagerForTesting();
+ return TestDataUtils.createModelFileManagerForTesting(context);
}
return modelFileManager;
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
index 8ef3908..de819ef 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
@@ -16,29 +16,31 @@
package com.android.textclassifier;
+import static com.android.textclassifier.ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT;
import static com.google.common.truth.Truth.assertThat;
-import static org.mockito.ArgumentMatchers.anyBoolean;
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.when;
import android.os.LocaleList;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType.ModelTypeDef;
+import com.android.textclassifier.ModelFileManager.RegularFileFullMatchLister;
+import com.android.textclassifier.ModelFileManager.RegularFilePatternMatchLister;
import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
+import com.google.common.io.Files;
import java.io.File;
import java.io.IOException;
-import java.util.Collections;
import java.util.List;
import java.util.Locale;
-import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
@@ -51,22 +53,16 @@
private static final String URL = "http://www.gstatic.com/android/text_classifier/q/711/en.fb";
private static final String URL_2 = "http://www.gstatic.com/android/text_classifier/q/712/en.fb";
- @ModelFile.ModelType.ModelTypeDef
- private static final String MODEL_TYPE = ModelFile.ModelType.ANNOTATOR;
+ @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
- @ModelFile.ModelType.ModelTypeDef
- private static final String MODEL_TYPE_2 = ModelFile.ModelType.LANG_ID;
+ @ModelTypeDef private static final String MODEL_TYPE_2 = ModelType.LANG_ID;
- @Mock private Supplier<ImmutableList<ModelFile>> modelFileSupplier;
@Mock private TextClassifierSettings.IDeviceConfig mockDeviceConfig;
- private File rootTestDir;
- private File factoryModelDir;
- private File configUpdaterModelFile;
- private File downloaderModelDir;
+ @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+ private File rootTestDir;
private ModelFileManager modelFileManager;
- private ModelFileManager.ModelFileSupplierImpl modelFileSupplierImpl;
@Before
public void setup() {
@@ -74,28 +70,11 @@
rootTestDir =
new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
- factoryModelDir = new File(rootTestDir, "factory");
- configUpdaterModelFile = new File(rootTestDir, "configupdater.model");
- downloaderModelDir = new File(rootTestDir, "downloader");
-
- modelFileManager =
- new ModelFileManager(downloaderModelDir, ImmutableMap.of(MODEL_TYPE, modelFileSupplier));
- modelFileSupplierImpl =
- new ModelFileManager.ModelFileSupplierImpl(
- new TextClassifierSettings(mockDeviceConfig),
- MODEL_TYPE,
- factoryModelDir,
- "test\\d.model",
- configUpdaterModelFile,
- downloaderModelDir,
- fd -> 1,
- fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT);
-
rootTestDir.mkdirs();
- factoryModelDir.mkdirs();
- downloaderModelDir.mkdirs();
-
- Locale.setDefault(DEFAULT_LOCALE);
+ modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ new TextClassifierSettings(mockDeviceConfig));
}
@After
@@ -104,100 +83,106 @@
}
@Test
- public void get() {
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
- when(modelFileSupplier.get()).thenReturn(ImmutableList.of(modelFile));
+ public void annotatorModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
+ }
- List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(MODEL_TYPE);
+ @Test
+ public void actionsModelPreloaded() {
+ verifyModelPreloadedAsAsset(
+ ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
+ }
- assertThat(modelFiles).hasSize(1);
- assertThat(modelFiles.get(0)).isEqualTo(modelFile);
+ @Test
+ public void langIdModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
+ }
+
+ private void verifyModelPreloadedAsAsset(
+ @ModelTypeDef String modelType, String expectedModelPath) {
+ List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
+ List<ModelFile> assetFiles =
+ modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());
+
+ assertThat(assetFiles).hasSize(1);
+ assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
}
@Test
public void findBestModel_versionCode() {
ModelFileManager.ModelFile olderModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
ModelFileManager.ModelFile newerModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/b"), 2, ImmutableList.of(), "", true);
- when(modelFileSupplier.get()).thenReturn(ImmutableList.of(olderModelFile, newerModelFile));
+ MODEL_TYPE,
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(olderModelFile, newerModelFile)));
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.getEmptyLocaleList());
-
+ ModelFile bestModelFile = modelFileManager.findBestModelFile(MODEL_TYPE, null);
assertThat(bestModelFile).isEqualTo(newerModelFile);
}
@Test
public void findBestModel_languageDependentModelIsPreferred() {
- Locale locale = Locale.forLanguageTag("ja");
ModelFileManager.ModelFile languageIndependentModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
ModelFileManager.ModelFile languageDependentModelFile =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(locale),
- locale.toLanguageTag(),
- false);
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
+ DEFAULT_LOCALE.toLanguageTag(),
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType ->
+ ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(
- MODEL_TYPE, LocaleList.forLanguageTags(locale.toLanguageTag()));
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
}
@Test
public void findBestModel_noMatchedLanguageModel() {
- Locale locale = Locale.forLanguageTag("ja");
ModelFileManager.ModelFile languageIndependentModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
ModelFileManager.ModelFile languageDependentModelFile =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(locale),
- locale.toLanguageTag(),
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
- assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
- }
-
- @Test
- public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() {
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(DEFAULT_LOCALE),
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
DEFAULT_LOCALE.toLanguageTag(),
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType ->
+ ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
ModelFileManager.ModelFile bestModelFile =
modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
@@ -209,23 +194,22 @@
ModelFileManager.ModelFile matchButOlderModel =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("fr")),
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
"fr",
- false);
-
+ /* isAsset= */ false);
ModelFileManager.ModelFile mismatchButNewerModel =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 2,
- Collections.singletonList(Locale.forLanguageTag("ja")),
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
"ja",
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(matchButOlderModel, mismatchButNewerModel));
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType -> ImmutableList.of(matchButOlderModel, mismatchButNewerModel)));
ModelFileManager.ModelFile bestModelFile =
modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("fr"));
@@ -233,21 +217,26 @@
}
@Test
- public void findBestModel_languageIsMoreImportantThanVersion_bestModelComesFirst() {
+ public void findBestModel_preferMatchedLocaleModel() {
ModelFileManager.ModelFile matchLocaleModel =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
"ja",
- false);
-
+ /* isAsset= */ false);
ModelFileManager.ModelFile languageIndependentModel =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(), "", true);
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(matchLocaleModel, languageIndependentModel));
+ MODEL_TYPE,
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType -> ImmutableList.of(matchLocaleModel, languageIndependentModel)));
ModelFileManager.ModelFile bestModelFile =
modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("ja"));
@@ -256,9 +245,135 @@
}
@Test
+ public void deleteUnusedModelFiles_olderModelDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "ja", /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isFalse();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_languageIndependentOlderModelDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ model1.getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ model2.getAbsolutePath(),
+ /* version= */ 2,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isFalse();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_modelOnlySupportingLocalesNotInListDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 1, "en", /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ assertThat(model2.exists()).isFalse();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_multiLocalesInLocaleList() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "en", /* isAsset= */ false);
+ setDefaultLocalesRule.set(
+ new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("en")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_readOnlyModelsUntouched() throws Exception {
+ File readOnlyDir = new File(rootTestDir, "read_only/");
+ readOnlyDir.mkdirs();
+ File model1 = new File(readOnlyDir, "model1.fb");
+ model1.createNewFile();
+ readOnlyDir.setWritable(false);
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile)));
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ }
+
+ @Test
public void getDownloadTargetFile_targetFileInCorrectDir() {
File targetFile = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL);
- assertThat(targetFile.getParentFile()).isEqualTo(downloaderModelDir);
+ assertThat(targetFile.getAbsolutePath())
+ .startsWith(ApplicationProvider.getApplicationContext().getFilesDir().getAbsolutePath());
}
@Test
@@ -278,21 +393,11 @@
public void modelFileEquals() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
assertThat(modelA).isEqualTo(modelB);
}
@@ -301,67 +406,23 @@
public void modelFile_different() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
assertThat(modelA).isNotEqualTo(modelB);
}
@Test
- public void modelFile_getPath() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA.getPath()).isEqualTo("/path/a");
- }
-
- @Test
- public void modelFile_getName() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA.getName()).isEqualTo("a");
- }
-
- @Test
public void modelFile_isPreferredTo_languageDependentIsBetter() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/b"), 2, ImmutableList.of(), "", true);
+ MODEL_TYPE, "/path/b", /* version= */ 2, LANGUAGE_INDEPENDENT, /* isAsset= */ false);
assertThat(modelA.isPreferredTo(modelB)).isTrue();
}
@@ -370,16 +431,11 @@
public void modelFile_isPreferredTo_version() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 2,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/b"), 1, ImmutableList.of(), "", false);
+ MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
assertThat(modelA.isPreferredTo(modelB)).isTrue();
}
@@ -388,7 +444,7 @@
public void modelFile_toModelInfo() {
ModelFileManager.ModelFile modelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
+ MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
ModelInfo modelInfo = modelFile.toModelInfo();
@@ -398,11 +454,9 @@
@Test
public void modelFile_toModelInfos() {
ModelFile englishModelFile =
- new ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(Locale.ENGLISH), "en", false);
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "en", /* isAsset= */ false);
ModelFile japaneseModelFile =
- new ModelFile(
- MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
ImmutableList<Optional<ModelInfo>> modelInfos =
ModelFileManager.ModelFile.toModelInfos(
@@ -417,64 +471,53 @@
}
@Test
- public void testFileSupplierImpl_updatedFileOnly() throws IOException {
- when(mockDeviceConfig.getBoolean(
- eq(TextClassifierSettings.NAMESPACE),
- eq(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED),
- anyBoolean()))
- .thenReturn(false);
- configUpdaterModelFile.createNewFile();
- File downloaderModelFile = new File(downloaderModelDir, "test0.model");
- downloaderModelFile.createNewFile();
- File model1 = new File(factoryModelDir, "test1.model");
- model1.createNewFile();
- File model2 = new File(factoryModelDir, "test2.model");
- model2.createNewFile();
- new File(factoryModelDir, "not_match_regex.model").createNewFile();
+ public void regularFileFullMatchLister() throws IOException {
+ File modelFile = new File(rootTestDir, "test.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
+ File wrongFile = new File(rootTestDir, "wrong.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
- List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
- List<String> modelFilePaths =
- modelFiles.stream().map(modelFile -> modelFile.getPath()).collect(Collectors.toList());
+ RegularFileFullMatchLister regularFileFullMatchLister =
+ new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
+ ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);
- assertThat(modelFiles).hasSize(3);
- assertThat(modelFilePaths)
- .containsExactly(
- configUpdaterModelFile.getAbsolutePath(),
- model1.getAbsolutePath(),
- model2.getAbsolutePath());
+ assertThat(listedModels).hasSize(1);
+ assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
+ assertThat(listedModels.get(0).isAsset).isFalse();
}
@Test
- public void testFileSupplierImpl_includeDownloaderFile() throws IOException {
- when(mockDeviceConfig.getBoolean(
- eq(TextClassifierSettings.NAMESPACE),
- eq(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED),
- anyBoolean()))
- .thenReturn(true);
- configUpdaterModelFile.createNewFile();
- File downloaderModelFile = new File(downloaderModelDir, "test0.model");
- downloaderModelFile.createNewFile();
- File factoryModelFile = new File(factoryModelDir, "test1.model");
- factoryModelFile.createNewFile();
+ public void regularFilePatternMatchLister() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
+ File modelFile2 = new File(rootTestDir, "annotator.fr.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
+ File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
- List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
- List<String> modelFilePaths =
- modelFiles.stream().map(ModelFile::getPath).collect(Collectors.toList());
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
- assertThat(modelFiles).hasSize(3);
- assertThat(modelFilePaths)
- .containsExactly(
- configUpdaterModelFile.getAbsolutePath(),
- downloaderModelFile.getAbsolutePath(),
- factoryModelFile.getAbsolutePath());
+ assertThat(listedModels).hasSize(2);
+ assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile1.getAbsolutePath());
+ assertThat(listedModels.get(0).isAsset).isFalse();
+ assertThat(listedModels.get(1).absolutePath).isEqualTo(modelFile2.getAbsolutePath());
+ assertThat(listedModels.get(1).isAsset).isFalse();
}
@Test
- public void testFileSupplierImpl_empty() {
- factoryModelDir.delete();
- List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
+ public void regularFilePatternMatchLister_disabled() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
- assertThat(modelFiles).hasSize(0);
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).isEmpty();
}
private static void recursiveDelete(File f) {
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
index 1acdebf..7565a0b 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
@@ -16,63 +16,44 @@
package com.android.textclassifier;
-import com.android.textclassifier.ModelFileManager.ModelFile;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import android.content.Context;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.ModelFileManager.RegularFileFullMatchLister;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
import java.io.File;
-import java.util.Locale;
-import java.util.function.Supplier;
/** Utils to access test data files. */
public final class TestDataUtils {
- private static final ImmutableMap<String, Supplier<ImmutableList<ModelFile>>>
- MODEL_FILES_SUPPLIER =
- new ImmutableMap.Builder<String, Supplier<ImmutableList<ModelFile>>>()
- .put(
- ModelType.ANNOTATOR,
- () ->
- ImmutableList.of(
- new ModelFile(
- ModelType.ANNOTATOR,
- new File(
- TestDataUtils.getTestDataFolder(), "testdata/annotator.model"),
- 711,
- ImmutableList.of(Locale.ENGLISH),
- "en",
- false)))
- .put(
- ModelType.ACTIONS_SUGGESTIONS,
- () ->
- ImmutableList.of(
- new ModelFile(
- ModelType.ACTIONS_SUGGESTIONS,
- new File(TestDataUtils.getTestDataFolder(), "testdata/actions.model"),
- 104,
- ImmutableList.of(Locale.ENGLISH),
- "en",
- false)))
- .put(
- ModelType.LANG_ID,
- () ->
- ImmutableList.of(
- new ModelFile(
- ModelType.LANG_ID,
- new File(TestDataUtils.getTestDataFolder(), "testdata/langid.model"),
- 1,
- ImmutableList.of(),
- "*",
- true)))
- .build();
+ private static final String TEST_ANNOTATOR_MODEL_PATH = "testdata/annotator.model";
+ private static final String TEST_ACTIONS_MODEL_PATH = "testdata/actions.model";
+ private static final String TEST_LANGID_MODEL_PATH = "testdata/langid.model";
/** Returns the root folder that contains the test data. */
public static File getTestDataFolder() {
return new File("/data/local/tmp/TextClassifierServiceTest/");
}
- public static ModelFileManager createModelFileManagerForTesting() {
+ /** Returns the annotator test model file, resolved against {@link #getTestDataFolder()}. */
+ public static File getTestAnnotatorModelFile() {
+ File testDataFolder = getTestDataFolder();
+ return new File(testDataFolder, TEST_ANNOTATOR_MODEL_PATH);
+ }
+
+ /** Returns the actions test model file, resolved against {@link #getTestDataFolder()}. */
+ public static File getTestActionsModelFile() {
+ File testDataFolder = getTestDataFolder();
+ return new File(testDataFolder, TEST_ACTIONS_MODEL_PATH);
+ }
+
+ /** Returns the language-id test model file, resolved against {@link #getTestDataFolder()}. */
+ public static File getLangIdModelFile() {
+ File testDataFolder = getTestDataFolder();
+ return new File(testDataFolder, TEST_LANGID_MODEL_PATH);
+ }
+
+ public static ModelFileManager createModelFileManagerForTesting(Context context) {
return new ModelFileManager(
- /* downloadModelDir= */ TestDataUtils.getTestDataFolder(), MODEL_FILES_SUPPLIER);
+ context,
+ ImmutableList.of(
+ new RegularFileFullMatchLister(
+ ModelType.ANNOTATOR, getTestAnnotatorModelFile(), () -> true),
+ new RegularFileFullMatchLister(
+ ModelType.ACTIONS_SUGGESTIONS, getTestActionsModelFile(), () -> true),
+ new RegularFileFullMatchLister(ModelType.LANG_ID, getLangIdModelFile(), () -> true)));
}
private TestDataUtils() {}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
new file mode 100644
index 0000000..0031368
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
@@ -0,0 +1,197 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.UiAutomation;
+import android.icu.util.ULocale;
+import android.provider.DeviceConfig;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationManager;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextLanguage;
+import android.view.textclassifier.TextLinks;
+import android.view.textclassifier.TextLinks.TextLink;
+import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import androidx.test.platform.app.InstrumentationRegistry;
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExternalResource;
+import org.junit.runner.RunWith;
+
+/**
+ * End-to-end tests for the {@link TextClassifier} APIs.
+ *
+ * <p>Unlike {@link TextClassifierImplTest}, we are trying to run the tests in an environment that
+ * is closer to the production environment. For example, we are not injecting the model files.
+ */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassifierApiTest {
+
+ // Classifier under test; obtained from the rule in setup().
+ private TextClassifier textClassifier;
+
+ @Rule
+ public final ExtServicesTextClassifierRule extServicesTextClassifierRule =
+ new ExtServicesTextClassifierRule();
+
+ /** Fetches the text classifier from the rule before each test. */
+ @Before
+ public void setup() {
+ textClassifier = extServicesTextClassifierRule.getTextClassifier();
+ }
+
+ /** Selecting "http" inside a URL should grow the selection to the whole URL, typed as a URL. */
+ @Test
+ public void suggestSelection() {
+ String text = "Visit http://www.android.com for more information";
+ String selected = "http";
+ String suggested = "http://www.android.com";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ int smartStartIndex = text.indexOf(suggested);
+ int smartEndIndex = smartStartIndex + suggested.length();
+
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
+
+ TextSelection selection = textClassifier.suggestSelection(request);
+ assertThat(selection.getEntityCount()).isGreaterThan(0);
+ assertThat(selection.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ assertThat(selection.getSelectionStartIndex()).isEqualTo(smartStartIndex);
+ assertThat(selection.getSelectionEndIndex()).isEqualTo(smartEndIndex);
+ }
+
+ /** Classifying an email address should yield the email type and more than one action. */
+ @Test
+ public void classifyText() {
+ String text = "Contact me at droid@android.com";
+ String classifiedText = "droid@android.com";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
+
+ TextClassification classification = textClassifier.classifyText(request);
+ assertThat(classification.getEntityCount()).isGreaterThan(0);
+ assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_EMAIL);
+ assertThat(classification.getText()).isEqualTo(classifiedText);
+ assertThat(classification.getActions().size()).isGreaterThan(1);
+ }
+
+ /** Text containing a single URL should produce exactly one link, typed as a URL. */
+ @Test
+ public void generateLinks() {
+ String text = "Check this out, http://www.android.com";
+
+ TextLinks.Request request = new TextLinks.Request.Builder(text).build();
+
+ TextLinks textLinks = textClassifier.generateLinks(request);
+
+ List<TextLink> links = new ArrayList<>(textLinks.getLinks());
+ assertThat(textLinks.getText().toString()).isEqualTo(text);
+ assertThat(links).hasSize(1);
+ assertThat(links.get(0).getEntityCount()).isGreaterThan(0);
+ assertThat(links.get(0).getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ assertThat(links.get(0).getConfidenceScore(TextClassifier.TYPE_URL)).isGreaterThan(0);
+ }
+
+ /** Japanese text should be detected with "ja" as the top locale hypothesis. */
+ @Test
+ public void detectedLanguage() {
+ String text = "朝、ピカチュウ";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+
+ TextLanguage textLanguage = textClassifier.detectLanguage(request);
+
+ assertThat(textLanguage.getLocaleHypothesisCount()).isGreaterThan(0);
+ assertThat(textLanguage.getLocale(0).getLanguage()).isEqualTo("ja");
+ assertThat(textLanguage.getConfidenceScore(ULocale.JAPANESE)).isGreaterThan(0);
+ }
+
+ /** A message containing a URL should produce a single open-URL conversation action. */
+ @Test
+ public void suggestConversationActions() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Check this out: https://www.android.com")
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(ImmutableList.of(message)).build();
+
+ ConversationActions conversationActions = textClassifier.suggestConversationActions(request);
+
+ assertThat(conversationActions.getConversationActions()).hasSize(1);
+ ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
+ assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
+ assertThat(conversationAction.getAction()).isNotNull();
+ }
+
+ /** A rule that manages a text classifier that is backed by the ExtServices. */
+ private static class ExtServicesTextClassifierRule extends ExternalResource {
+ // Name of the DeviceConfig flag (text classifier namespace) that this rule overrides.
+ private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE =
+ "textclassifier_service_package_override";
+
+ // Flag value read in before(), written back in after(); null if the flag had no value.
+ private String textClassifierServiceOverrideFlagOldValue;
+
+ /**
+ * Remembers the current value of the override flag and then sets the flag to the ExtServices
+ * package name. Shell permission identity is adopted only around the DeviceConfig calls.
+ */
+ @Override
+ protected void before() {
+ UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ try {
+ uiAutomation.adoptShellPermissionIdentity();
+ textClassifierServiceOverrideFlagOldValue =
+ DeviceConfig.getString(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ null);
+ // NOTE(review): the package name is hard-coded; confirm it matches the ExtServices
+ // package on every build flavor this test is expected to run on.
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ "com.google.android.ext.services",
+ /* makeDefault= */ false);
+ } finally {
+ uiAutomation.dropShellPermissionIdentity();
+ }
+ }
+
+ /** Writes the flag value remembered in {@link #before()} back to DeviceConfig. */
+ @Override
+ protected void after() {
+ UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ try {
+ uiAutomation.adoptShellPermissionIdentity();
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ textClassifierServiceOverrideFlagOldValue,
+ /* makeDefault= */ false);
+ } finally {
+ uiAutomation.dropShellPermissionIdentity();
+ }
+ }
+
+ /** Returns the text classifier provided by the {@link TextClassificationManager}. */
+ public TextClassifier getTextClassifier() {
+ TextClassificationManager textClassificationManager =
+ ApplicationProvider.getApplicationContext()
+ .getSystemService(TextClassificationManager.class);
+ return textClassificationManager.getTextClassifier();
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index 5007e2a..06ec640 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -38,10 +38,12 @@
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import com.android.textclassifier.testing.FakeContextBuilder;
import com.google.common.collect.ImmutableList;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -63,7 +65,7 @@
private TextClassifierImpl classifier;
private final ModelFileManager modelFileManager =
- TestDataUtils.createModelFileManagerForTesting();
+ TestDataUtils.createModelFileManagerForTesting(ApplicationProvider.getApplicationContext());
@Before
public void setup() {
@@ -77,7 +79,7 @@
}
@Test
- public void testSuggestSelection() {
+ public void testSuggestSelection() throws IOException {
String text = "Contact me at droid@android.com";
String selected = "droid";
String suggested = "droid@android.com";
@@ -96,7 +98,7 @@
}
@Test
- public void testSuggestSelection_url() {
+ public void testSuggestSelection_url() throws IOException {
String text = "Visit http://www.android.com for more information";
String selected = "http";
String suggested = "http://www.android.com";
@@ -114,7 +116,7 @@
}
@Test
- public void testSmartSelection_withEmoji() {
+ public void testSmartSelection_withEmoji() throws IOException {
String text = "\uD83D\uDE02 Hello.";
String selected = "Hello";
int startIndex = text.indexOf(selected);
@@ -129,7 +131,7 @@
}
@Test
- public void testClassifyText() {
+ public void testClassifyText() throws IOException {
String text = "Contact me at droid@android.com";
String classifiedText = "droid@android.com";
int startIndex = text.indexOf(classifiedText);
@@ -144,7 +146,7 @@
}
@Test
- public void testClassifyText_url() {
+ public void testClassifyText_url() throws IOException {
String text = "Visit www.android.com for more information";
String classifiedText = "www.android.com";
int startIndex = text.indexOf(classifiedText);
@@ -160,7 +162,7 @@
}
@Test
- public void testClassifyText_address() {
+ public void testClassifyText_address() throws IOException {
String text = "Brandschenkestrasse 110, Zürich, Switzerland";
TextClassification.Request request =
new TextClassification.Request.Builder(text, 0, text.length())
@@ -172,7 +174,7 @@
}
@Test
- public void testClassifyText_url_inCaps() {
+ public void testClassifyText_url_inCaps() throws IOException {
String text = "Visit HTTP://ANDROID.COM for more information";
String classifiedText = "HTTP://ANDROID.COM";
int startIndex = text.indexOf(classifiedText);
@@ -188,7 +190,7 @@
}
@Test
- public void testClassifyText_date() {
+ public void testClassifyText_date() throws IOException {
String text = "Let's meet on January 9, 2018.";
String classifiedText = "January 9, 2018";
int startIndex = text.indexOf(classifiedText);
@@ -209,7 +211,7 @@
}
@Test
- public void testClassifyText_datetime() {
+ public void testClassifyText_datetime() throws IOException {
String text = "Let's meet 2018/01/01 10:30:20.";
String classifiedText = "2018/01/01 10:30:20";
int startIndex = text.indexOf(classifiedText);
@@ -224,7 +226,7 @@
}
@Test
- public void testClassifyText_foreignText() {
+ public void testClassifyText_foreignText() throws IOException {
LocaleList originalLocales = LocaleList.getDefault();
LocaleList.setDefault(LocaleList.forLanguageTags("en"));
String japaneseText = "これは日本語のテキストです";
@@ -253,7 +255,7 @@
}
@Test
- public void testGenerateLinks_phone() {
+ public void testGenerateLinks_phone() throws IOException {
String text = "The number is +12122537077. See you tonight!";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
assertThat(
@@ -262,7 +264,7 @@
}
@Test
- public void testGenerateLinks_exclude() {
+ public void testGenerateLinks_exclude() throws IOException {
String text = "You want apple@banana.com. See you tonight!";
List<String> hints = ImmutableList.of();
List<String> included = ImmutableList.of();
@@ -278,7 +280,7 @@
}
@Test
- public void testGenerateLinks_explicit_address() {
+ public void testGenerateLinks_explicit_address() throws IOException {
String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
TextLinks.Request request =
@@ -293,7 +295,7 @@
}
@Test
- public void testGenerateLinks_exclude_override() {
+ public void testGenerateLinks_exclude_override() throws IOException {
String text = "You want apple@banana.com. See you tonight!";
List<String> hints = ImmutableList.of();
List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
@@ -309,7 +311,7 @@
}
@Test
- public void testGenerateLinks_maxLength() {
+ public void testGenerateLinks_maxLength() throws IOException {
char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()];
Arrays.fill(manySpaces, ' ');
TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
@@ -318,7 +320,7 @@
}
@Test
- public void testApplyLinks_unsupportedCharacter() {
+ public void testApplyLinks_unsupportedCharacter() throws IOException {
Spannable url = new SpannableString("\u202Emoc.diordna.com");
TextLinks.Request request = new TextLinks.Request.Builder(url).build();
assertEquals(
@@ -335,7 +337,7 @@
}
@Test
- public void testGenerateLinks_entityData() {
+ public void testGenerateLinks_entityData() throws IOException {
String text = "The number is +12122537077.";
Bundle extras = new Bundle();
ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
@@ -352,7 +354,7 @@
}
@Test
- public void testGenerateLinks_entityData_disabled() {
+ public void testGenerateLinks_entityData_disabled() throws IOException {
String text = "The number is +12122537077.";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
@@ -365,7 +367,7 @@
}
@Test
- public void testDetectLanguage() {
+ public void testDetectLanguage() throws IOException {
String text = "This is English text";
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
TextLanguage textLanguage = classifier.detectLanguage(request);
@@ -373,7 +375,7 @@
}
@Test
- public void testDetectLanguage_japanese() {
+ public void testDetectLanguage_japanese() throws IOException {
String text = "これは日本語のテキストです";
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
TextLanguage textLanguage = classifier.detectLanguage(request);
@@ -381,7 +383,7 @@
}
@Test
- public void testSuggestConversationActions_textReplyOnly_maxOne() {
+ public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Where are you?")
@@ -405,7 +407,7 @@
}
@Test
- public void testSuggestConversationActions_textReplyOnly_noMax() {
+ public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Where are you?")
@@ -428,7 +430,7 @@
}
@Test
- public void testSuggestConversationActions_openUrl() {
+ public void testSuggestConversationActions_openUrl() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Check this out: https://www.android.com")
@@ -455,7 +457,7 @@
}
@Test
- public void testSuggestConversationActions_copy() {
+ public void testSuggestConversationActions_copy() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Authentication code: 12345")
@@ -483,7 +485,7 @@
}
@Test
- public void testSuggestConversationActions_deduplicate() {
+ public void testSuggestConversationActions_deduplicate() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("a@android.com b@android.com")
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
index c0a823e..e1e7982 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
@@ -22,7 +22,7 @@
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import androidx.test.platform.app.InstrumentationRegistry;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType;
import java.util.function.Consumer;
import org.junit.After;
import org.junit.Before;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java
new file mode 100644
index 0000000..ec1405b
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.testing;
+
+import android.os.LocaleList;
+import org.junit.rules.ExternalResource;
+
+/**
+ * JUnit rule for tests that change the default {@link LocaleList}.
+ *
+ * <p>The default locale list is captured before each test and put back after the test, so a
+ * change made through {@link #set(LocaleList)} does not outlive the test that made it.
+ */
+public class SetDefaultLocalesRule extends ExternalResource {
+
+ // Default locale list as it was when before() ran; restored by after().
+ private LocaleList savedDefaultLocales;
+
+ /** Makes {@code newValue} the default locale list. */
+ public void set(LocaleList newValue) {
+ LocaleList.setDefault(newValue);
+ }
+
+ /** Captures the current default locale list so that it can be restored later. */
+ @Override
+ protected void before() throws Throwable {
+ super.before();
+ savedDefaultLocales = LocaleList.getDefault();
+ }
+
+ /** Puts back the default locale list captured in {@link #before()}. */
+ @Override
+ protected void after() {
+ super.after();
+ LocaleList.setDefault(savedDefaultLocales);
+ }
+}