Export TextClassifier to sc-mainline-prod.
Use Merged-In to skip merging to master. Will resolve the conflict later.
Bug: 187927611
Merged-In: Ic08373aecd85a1c71814e7cc4798b54dc6281bdf
Change-Id: Icb51a375bea23d9767e0241bf50b1178f4382f0f
diff --git a/java/Android.bp b/java/Android.bp
index 30fd2bc..5948a17 100644
--- a/java/Android.bp
+++ b/java/Android.bp
@@ -57,6 +57,10 @@
"src/**/*.aidl",
],
manifest: "LibNoManifest_AndroidManifest.xml",
+ plugins: [
+ "auto_value_plugin",
+ "androidx.room_room-compiler-plugin",
+ ],
static_libs: [
"androidx.core_core",
"libtextclassifier-java",
@@ -69,6 +73,8 @@
"textclassifier-statsd",
"textclassifier-java-proto-lite",
"androidx.concurrent_concurrent-futures",
+ "auto_value_annotations",
+ "androidx.room_room-runtime",
],
sdk_version: "system_current",
min_sdk_version: "30",
diff --git a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
index beb155b..4838503 100644
--- a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
+++ b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
@@ -27,7 +27,7 @@
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.ConversationActions;
import android.view.textclassifier.ConversationActions.Message;
-import com.android.textclassifier.common.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.ModelFile;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.intent.LabeledIntent;
import com.android.textclassifier.common.intent.TemplateIntentFactory;
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index 707af6a..49acf2e 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -30,7 +30,6 @@
import android.view.textclassifier.TextSelection;
import androidx.annotation.NonNull;
import androidx.collection.LruCache;
-import com.android.textclassifier.common.ModelFileManager;
import com.android.textclassifier.common.TextClassifierServiceExecutors;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.base.TcLog;
@@ -88,10 +87,10 @@
modelDownloadManager =
new com.android.textclassifier.downloader.ModelDownloadManager(
injector.getContext().getApplicationContext(),
- modelFileManager,
settings,
TextClassifierServiceExecutors.getDownloaderExecutor());
modelDownloadManager.onTextClassifierServiceCreated();
+ modelFileManager.addModelDownloaderModels(modelDownloadManager, settings);
textClassifierApiUsageLogger =
injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor);
}
@@ -209,6 +208,7 @@
IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
// TODO(licha): Also dump ModelDownloadManager for debugging
textClassifier.dump(indentingPrintWriter);
+ modelDownloadManager.dump(indentingPrintWriter);
dumpImpl(indentingPrintWriter);
indentingPrintWriter.flush();
}
diff --git a/java/src/com/android/textclassifier/common/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
similarity index 86%
rename from java/src/com/android/textclassifier/common/ModelFileManager.java
rename to java/src/com/android/textclassifier/ModelFileManager.java
index 406a889..155cf56 100644
--- a/java/src/com/android/textclassifier/common/ModelFileManager.java
+++ b/java/src/com/android/textclassifier/ModelFileManager.java
@@ -14,19 +14,22 @@
* limitations under the License.
*/
-package com.android.textclassifier.common;
+package com.android.textclassifier;
import android.content.Context;
import android.content.res.AssetFileDescriptor;
import android.content.res.AssetManager;
import android.os.LocaleList;
import android.os.ParcelFileDescriptor;
-import android.util.ArraySet;
import androidx.annotation.GuardedBy;
import androidx.collection.ArrayMap;
+import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.downloader.ModelDownloadManager;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.android.textclassifier.ActionsSuggestionsModel;
import com.google.android.textclassifier.AnnotatorModel;
@@ -59,27 +62,19 @@
private static final String TAG = "ModelFileManager";
- private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models/";
private static final File CONFIG_UPDATER_DIR = new File("/data/misc/textclassifier/");
private static final String ASSETS_DIR = "textclassifier";
- private final List<ModelFileLister> modelFileListers;
- private final File modelDownloaderDir;
+ private ImmutableList<ModelFileLister> modelFileListers;
public ModelFileManager(Context context, TextClassifierSettings settings) {
Preconditions.checkNotNull(context);
Preconditions.checkNotNull(settings);
AssetManager assetManager = context.getAssets();
- this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
modelFileListers =
ImmutableList.of(
// Annotator models.
- new RegularFilePatternMatchLister(
- ModelType.ANNOTATOR,
- this.modelDownloaderDir,
- "annotator\\.(.*)\\.model",
- settings::isModelDownloadManagerEnabled),
new RegularFileFullMatchLister(
ModelType.ANNOTATOR,
new File(CONFIG_UPDATER_DIR, "textclassifier.model"),
@@ -91,11 +86,6 @@
"annotator\\.(.*)\\.model",
/* isEnabled= */ () -> true),
// Actions models.
- new RegularFilePatternMatchLister(
- ModelType.ACTIONS_SUGGESTIONS,
- this.modelDownloaderDir,
- "actions_suggestions\\.(.*)\\.model",
- settings::isModelDownloadManagerEnabled),
new RegularFileFullMatchLister(
ModelType.ACTIONS_SUGGESTIONS,
new File(CONFIG_UPDATER_DIR, "actions_suggestions.model"),
@@ -107,11 +97,6 @@
"actions_suggestions\\.(.*)\\.model",
/* isEnabled= */ () -> true),
// LangID models.
- new RegularFilePatternMatchLister(
- ModelType.LANG_ID,
- this.modelDownloaderDir,
- "lang_id\\.(.*)\\.model",
- settings::isModelDownloadManagerEnabled),
new RegularFileFullMatchLister(
ModelType.LANG_ID,
new File(CONFIG_UPDATER_DIR, "lang_id.model"),
@@ -126,10 +111,37 @@
@VisibleForTesting
public ModelFileManager(Context context, List<ModelFileLister> modelFileListers) {
- this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
this.modelFileListers = ImmutableList.copyOf(modelFileListers);
}
+ // TODO(licha): Move this to constructor and consider using DownloadedModelManager here
+ /** Enable ModelFileManager to scan and use models downloaded by model downloader. */
+ public void addModelDownloaderModels(
+ ModelDownloadManager modelDownloadManager, TextClassifierSettings settings) {
+ this.modelFileListers =
+ ImmutableList.<ModelFileLister>builder()
+ .addAll(modelFileListers)
+ .add(
+ modelType -> {
+ ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+ if (settings.isModelDownloadManagerEnabled()) {
+ for (File modelFile : modelDownloadManager.listDownloadedModels(modelType)) {
+ try {
+ // TODO(licha): Make the ModelFile class public and construct downloader
+ // model files with locale tag in our internal database
+ modelFilesBuilder.add(
+ ModelFile.createFromRegularFile(modelFile, modelType));
+ } catch (IOException e) {
+ TcLog.e(
+ TAG, "Failed to create ModelFile: " + modelFile.getAbsolutePath(), e);
+ }
+ }
+ }
+ return modelFilesBuilder.build();
+ })
+ .build();
+ }
+
/**
* Returns an immutable list of model files listed by the given model files supplier.
*
@@ -146,6 +158,7 @@
}
/** Lists model files. */
+ @FunctionalInterface
public interface ModelFileLister {
List<ModelFile> list(@ModelTypeDef String modelType);
}
@@ -328,16 +341,36 @@
@Nullable
public ModelFile findBestModelFile(
@ModelTypeDef String modelType, @Nullable LocaleList localePreferences) {
- final String languages =
- localePreferences == null || localePreferences.isEmpty()
- ? LocaleList.getDefault().toLanguageTags()
- : localePreferences.toLanguageTags();
- final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
+ boolean useDefaultLocale = localePreferences == null || localePreferences.isEmpty();
+ Locale targetLocale = useDefaultLocale ? Locale.getDefault() : localePreferences.get(0);
+ return findBestModelFile(modelType, targetLocale);
+ }
+ /**
+ * Returns the best model file for the given locale, {@code null} if nothing is found.
+ *
+ * @param modelType the type of model to look up (e.g. annotator, lang_id, etc.)
+ * @param targetLocale the preferred locale
+ */
+ @Nullable
+ private ModelFile findBestModelFile(@ModelTypeDef String modelType, Locale targetLocale) {
+ List<Locale.LanguageRange> deviceLanguageRanges =
+ Locale.LanguageRange.parse(LocaleList.getDefault().toLanguageTags());
+ boolean languageIndependentModelOnly = false;
+ if (Locale.lookupTag(deviceLanguageRanges, ImmutableList.of(targetLocale.getLanguage()))
+ == null) {
+ // If the targetLocale's language is not in device locale list, we don't match it to avoid
+ // leaking user language profile to the callers.
+ languageIndependentModelOnly = true;
+ }
+ List<Locale.LanguageRange> targetLanguageRanges =
+ Locale.LanguageRange.parse(targetLocale.toLanguageTag());
ModelFile bestModel = null;
for (ModelFile model : listModelFiles(modelType)) {
- // TODO(licha): update this when we want to support multiple languages
- if (model.isAnyLanguageSupported(languageRangeList)) {
+ if (languageIndependentModelOnly && !model.languageIndependent) {
+ continue;
+ }
+ if (model.isAnyLanguageSupported(targetLanguageRanges)) {
if (model.isPreferredTo(bestModel)) {
bestModel = model;
}
@@ -347,39 +380,6 @@
}
/**
- * Deletes model files that are not preferred for any locales in user's preference.
- *
- * <p>This method will be invoked as a clean-up after we download a new model successfully. Race
- * conditions are hard to avoid because we do not hold locks for files. But it should rarely cause
- * any issues since it's safe to delete a model file in use (b/c we mmap it to memory).
- */
- public void deleteUnusedModelFiles() {
- TcLog.d(TAG, "Start to delete unused model files.");
- LocaleList localeList = LocaleList.getDefault();
- for (@ModelTypeDef String modelType : ModelType.values()) {
- ArraySet<ModelFile> allModelFiles = new ArraySet<>(listModelFiles(modelType));
- for (int i = 0; i < localeList.size(); i++) {
- // If a model file is preferred for any local in locale list, then keep it
- ModelFile bestModel = findBestModelFile(modelType, new LocaleList(localeList.get(i)));
- allModelFiles.remove(bestModel);
- }
- for (ModelFile modelFile : allModelFiles) {
- if (modelFile.canWrite()) {
- TcLog.d(TAG, "Deleting model: " + modelFile);
- if (!modelFile.delete()) {
- TcLog.w(TAG, "Failed to delete model: " + modelFile);
- }
- }
- }
- }
- }
-
- /** Returns the directory containing models downloaded by the downloader. */
- public File getModelDownloaderDir() {
- return modelDownloaderDir;
- }
-
- /**
* Dumps the internal state for debugging.
*
* @param printWriter writer to write dumped states
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index bf326fb..3c28466 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -44,8 +44,7 @@
import androidx.annotation.GuardedBy;
import androidx.annotation.WorkerThread;
import androidx.core.util.Pair;
-import com.android.textclassifier.common.ModelFileManager;
-import com.android.textclassifier.common.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.ModelFile;
import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.base.TcLog;
diff --git a/java/src/com/android/textclassifier/common/TextClassifierSettings.java b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
index fdf259e..b6ec1c2 100644
--- a/java/src/com/android/textclassifier/common/TextClassifierSettings.java
+++ b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
@@ -20,6 +20,7 @@
import android.provider.DeviceConfig;
import android.provider.DeviceConfig.Properties;
+import android.text.TextUtils;
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.TextClassifier;
import androidx.annotation.NonNull;
@@ -117,13 +118,19 @@
"manifest_download_required_network_type";
/** Max attempts allowed for a single ModelDownloader downloading task. */
@VisibleForTesting
- static final String MODEL_DOWNLOAD_MAX_ATTEMPTS = "model_download_max_attempts";
+ static final String MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS = "model_download_worker_max_attempts";
+ /** Max attempts allowed for a certain manifest url. */
+ @VisibleForTesting
+ static final String MANIFEST_DOWNLOAD_MAX_ATTEMPTS = "manifest_download_max_attempts";
@VisibleForTesting
static final String MODEL_DOWNLOAD_BACKOFF_DELAY_IN_MILLIS =
"model_download_backoff_delay_in_millis";
/** Flag name for manifest url is dynamically formatted based on model type and model language. */
@VisibleForTesting public static final String MANIFEST_URL_TEMPLATE = "manifest_url_%s_%s";
+
+ @VisibleForTesting public static final String MODEL_URL_BLOCKLIST = "model_url_blocklist";
+ @VisibleForTesting public static final String MODEL_URL_BLOCKLIST_SEPARATOR = ",";
/** Sampling rate for TextClassifier API logging. */
static final String TEXTCLASSIFIER_API_LOG_SAMPLE_RATE = "textclassifier_api_log_sample_rate";
@@ -195,7 +202,8 @@
private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
private static final boolean MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT = false;
private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT = "UNMETERED";
- private static final int MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT = 5;
+ private static final int MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS_DEFAULT = 5;
+ private static final int MANIFEST_DOWNLOAD_MAX_ATTEMPTS_DEFAULT = 2;
private static final long MODEL_DOWNLOAD_BACKOFF_DELAY_IN_MILLIS_DEFAULT = HOURS.toMillis(1);
private static final String MANIFEST_URL_DEFAULT = "";
private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
@@ -380,9 +388,14 @@
MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT);
}
- public int getModelDownloadMaxAttempts() {
+ public int getModelDownloadWorkerMaxAttempts() {
return deviceConfig.getInt(
- NAMESPACE, MODEL_DOWNLOAD_MAX_ATTEMPTS, MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT);
+ NAMESPACE, MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS, MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS_DEFAULT);
+ }
+
+ public int getManifestDownloadMaxAttempts() {
+ return deviceConfig.getInt(
+ NAMESPACE, MANIFEST_DOWNLOAD_MAX_ATTEMPTS, MANIFEST_DOWNLOAD_MAX_ATTEMPTS_DEFAULT);
}
public long getModelDownloadBackoffDelayInMillis() {
@@ -406,12 +419,22 @@
return deviceConfig.getString(NAMESPACE, urlFlagName, MANIFEST_URL_DEFAULT);
}
+ /** Gets a list of model URLs that should not be used. Usually used for a quick rollback. */
+ public ImmutableList<String> getModelUrlBlocklist() {
+ return ImmutableList.copyOf(
+ Splitter.on(MODEL_URL_BLOCKLIST_SEPARATOR)
+ .split(deviceConfig.getString(NAMESPACE, MODEL_URL_BLOCKLIST, "")));
+ }
+
+ // TODO(licha): Let this method return a <localeTag, flagValue> map.
/**
* Gets all language variants configured for a specific ModelType.
*
* <p>For a specific language, there can be many variants: de-CH, de-LI, zh-Hans, zh-Hant. There
* is no easy way to hardcode the list in client. Therefore, we parse all configured flag's name
* in DeviceConfig, and let the client to choose the best variant to download.
+ *
+ * <p>If one flag's value is empty, it will be ignored.
*/
public ImmutableList<String> getLanguageTagsForManifestURL(
@ModelType.ModelTypeDef String modelType) {
@@ -419,8 +442,11 @@
Properties properties = deviceConfig.getProperties(NAMESPACE);
ImmutableList.Builder<String> variantsBuilder = ImmutableList.builder();
for (String name : properties.getKeyset()) {
- if (name.startsWith(urlFlagBaseName)
- && properties.getString(name, /* defaultValue= */ null) != null) {
+ if (!name.startsWith(urlFlagBaseName)) {
+ continue;
+ }
+ String value = properties.getString(name, /* defaultValue= */ null);
+ if (!TextUtils.isEmpty(value)) {
variantsBuilder.add(name.substring(urlFlagBaseName.length()));
}
}
@@ -458,7 +484,8 @@
pw.printPair(TEMPLATE_INTENT_FACTORY_ENABLED, isTemplateIntentFactoryEnabled());
pw.printPair(TRANSLATE_IN_CLASSIFICATION_ENABLED, isTranslateInClassificationEnabled());
pw.printPair(MODEL_DOWNLOAD_MANAGER_ENABLED, isModelDownloadManagerEnabled());
- pw.printPair(MODEL_DOWNLOAD_MAX_ATTEMPTS, getModelDownloadMaxAttempts());
+ pw.printPair(MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS, getModelDownloadWorkerMaxAttempts());
+ pw.printPair(MANIFEST_DOWNLOAD_MAX_ATTEMPTS, getManifestDownloadMaxAttempts());
pw.decreaseIndent();
pw.printPair(TEXTCLASSIFIER_API_LOG_SAMPLE_RATE, getTextClassifierApiLogSampleRate());
pw.printPair(SESSION_ID_TO_CONTEXT_CACHE_SIZE, getSessionIdToContextCacheSize());
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelDatabase.java b/java/src/com/android/textclassifier/downloader/DownloadedModelDatabase.java
new file mode 100644
index 0000000..0614c2f
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelDatabase.java
@@ -0,0 +1,381 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import androidx.annotation.IntDef;
+import androidx.annotation.NonNull;
+import androidx.room.ColumnInfo;
+import androidx.room.Dao;
+import androidx.room.Database;
+import androidx.room.DatabaseView;
+import androidx.room.Delete;
+import androidx.room.Embedded;
+import androidx.room.Entity;
+import androidx.room.ForeignKey;
+import androidx.room.Index;
+import androidx.room.Insert;
+import androidx.room.OnConflictStrategy;
+import androidx.room.Query;
+import androidx.room.RoomDatabase;
+import androidx.room.Transaction;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.auto.value.AutoValue;
+import com.google.common.collect.Iterables;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+
+/** Database storing info about models downloaded by the model downloader. */
+@Database(
+ entities = {
+ DownloadedModelDatabase.Model.class,
+ DownloadedModelDatabase.Manifest.class,
+ DownloadedModelDatabase.ManifestModelCrossRef.class,
+ DownloadedModelDatabase.ManifestEnrollment.class
+ },
+ views = {DownloadedModelDatabase.ModelView.class},
+ version = 1,
+ exportSchema = true)
+abstract class DownloadedModelDatabase extends RoomDatabase {
+ public static final String TAG = "DownloadedModelDatabase";
+
+ /** Represents a downloaded model file. */
+ @AutoValue
+ @Entity(
+ tableName = "model",
+ primaryKeys = {"model_url"})
+ abstract static class Model {
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "model_url")
+ @NonNull
+ public abstract String getModelUrl();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "model_path")
+ @NonNull
+ public abstract String getModelPath();
+
+ public static Model create(String modelUrl, String modelPath) {
+ return new AutoValue_DownloadedModelDatabase_Model(modelUrl, modelPath);
+ }
+ }
+
+ /** Represents a manifest we processed. */
+ @AutoValue
+ @Entity(
+ tableName = "manifest",
+ primaryKeys = {"manifest_url"})
+ abstract static class Manifest {
+ // TODO(licha): Consider using Enum here
+ @Retention(RetentionPolicy.SOURCE)
+ @IntDef({STATUS_UNKNOWN, STATUS_FAILED, STATUS_SUCCEEDED})
+ @interface StatusDef {}
+
+ public static final int STATUS_UNKNOWN = 0;
+ /** Failed to download this manifest. Could be retried in the future. */
+ public static final int STATUS_FAILED = 1;
+ /** Downloaded this manifest successfully and it's currently in storage. */
+ public static final int STATUS_SUCCEEDED = 2;
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "manifest_url")
+ @NonNull
+ public abstract String getManifestUrl();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "status")
+ @StatusDef
+ public abstract int getStatus();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "failure_counts")
+ public abstract int getFailureCounts();
+
+ public static Manifest create(String manifestUrl, @StatusDef int status, int failureCounts) {
+ return new AutoValue_DownloadedModelDatabase_Manifest(manifestUrl, status, failureCounts);
+ }
+ }
+
+ /**
+ * Represents the relationship between manifests and downloaded models.
+ *
+ * <p>A manifest can include multiple models, a model can also be included in multiple manifests.
+ * In different manifests, a model may have different configurations (e.g. primary model in
+ * manifest A but dark model in B).
+ */
+ @AutoValue
+ @Entity(
+ tableName = "manifest_model_cross_ref",
+ primaryKeys = {"manifest_url", "model_url"},
+ foreignKeys = {
+ @ForeignKey(
+ entity = Manifest.class,
+ parentColumns = "manifest_url",
+ childColumns = "manifest_url",
+ onDelete = ForeignKey.CASCADE),
+ @ForeignKey(
+ entity = Model.class,
+ parentColumns = "model_url",
+ childColumns = "model_url",
+ onDelete = ForeignKey.CASCADE),
+ },
+ indices = {
+ @Index(value = {"manifest_url"}),
+ @Index(value = {"model_url"}),
+ })
+ abstract static class ManifestModelCrossRef {
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "manifest_url")
+ @NonNull
+ public abstract String getManifestUrl();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "model_url")
+ @NonNull
+ public abstract String getModelUrl();
+
+ public static ManifestModelCrossRef create(String manifestUrl, String modelUrl) {
+ return new AutoValue_DownloadedModelDatabase_ManifestModelCrossRef(manifestUrl, modelUrl);
+ }
+ }
+
+ /**
+ * Represents the relationship between user scenarios and manifests.
+ *
+ * <p>For each unique user scenario (i.e. modelType + localeTag), we store the manifest we should
+ * use. The same manifest can be used for different scenarios.
+ */
+ @AutoValue
+ @Entity(
+ tableName = "manifest_enrollment",
+ primaryKeys = {"model_type", "locale_tag"},
+ foreignKeys = {
+ @ForeignKey(
+ entity = Manifest.class,
+ parentColumns = "manifest_url",
+ childColumns = "manifest_url",
+ onDelete = ForeignKey.CASCADE)
+ },
+ indices = {@Index(value = {"manifest_url"})})
+ abstract static class ManifestEnrollment {
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "model_type")
+ @NonNull
+ @ModelTypeDef
+ public abstract String getModelType();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "locale_tag")
+ @NonNull
+ public abstract String getLocaleTag();
+
+ @AutoValue.CopyAnnotations
+ @ColumnInfo(name = "manifest_url")
+ @NonNull
+ public abstract String getManifestUrl();
+
+ public static ManifestEnrollment create(
+ @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+ return new AutoValue_DownloadedModelDatabase_ManifestEnrollment(
+ modelType, localeTag, manifestUrl);
+ }
+ }
+
+ /** Represents the mapping from manifest enrollments to models. */
+ @AutoValue
+ @DatabaseView(
+ value =
+ "SELECT manifest_enrollment.*, model.* "
+ + "FROM manifest_enrollment "
+ + "INNER JOIN manifest_model_cross_ref "
+ + "ON manifest_enrollment.manifest_url = manifest_model_cross_ref.manifest_url "
+ + "INNER JOIN model "
+ + "ON manifest_model_cross_ref.model_url = model.model_url",
+ viewName = "model_view")
+ abstract static class ModelView {
+ @AutoValue.CopyAnnotations
+ @Embedded
+ @NonNull
+ public abstract ManifestEnrollment getManifestEnrollment();
+
+ @AutoValue.CopyAnnotations
+ @Embedded
+ @NonNull
+ public abstract Model getModel();
+
+ public static ModelView create(ManifestEnrollment manifestEnrollment, Model model) {
+ return new AutoValue_DownloadedModelDatabase_ModelView(manifestEnrollment, model);
+ }
+ }
+
+ @Dao
+ abstract static class DownloadedModelDatabaseDao {
+ // Full table scan
+ @Query("SELECT * FROM model")
+ abstract List<Model> queryAllModels();
+
+ @Query("SELECT * FROM manifest")
+ abstract List<Manifest> queryAllManifests();
+
+ @Query("SELECT * FROM manifest_model_cross_ref")
+ abstract List<ManifestModelCrossRef> queryAllManifestModelCrossRefs();
+
+ @Query("SELECT * FROM manifest_enrollment")
+ abstract List<ManifestEnrollment> queryAllManifestEnrollments();
+
+ @Query("SELECT * FROM model_view")
+ abstract List<ModelView> queryAllModelViews();
+
+ // Single table query
+ @Query("SELECT * FROM model WHERE model_url = :modelUrl")
+ abstract List<Model> queryModelWithModelUrl(String modelUrl);
+
+ @Query("SELECT * FROM manifest WHERE manifest_url = :manifestUrl")
+ abstract List<Manifest> queryManifestWithManifestUrl(String manifestUrl);
+
+ @Query(
+ "SELECT * FROM manifest_enrollment WHERE model_type = :modelType "
+ + "AND locale_tag = :localeTag")
+ abstract List<ManifestEnrollment> queryManifestEnrollmentWithModelTypeAndLocaleTag(
+ @ModelTypeDef String modelType, String localeTag);
+
+ // Helpers for clean up
+ @Query(
+ "SELECT manifest.* FROM manifest "
+ + "LEFT JOIN model_view "
+ + "ON manifest.manifest_url = model_view.manifest_url "
+ + "WHERE model_view.manifest_url IS NULL "
+ + "AND manifest.status = 2")
+ abstract List<Manifest> queryUnusedManifests();
+
+ @Query(
+ "SELECT * FROM manifest WHERE manifest.status = 1 "
+ + "AND manifest.manifest_url NOT IN (:manifestUrlsToKeep)")
+ abstract List<Manifest> queryUnusedManifestFailureRecords(List<String> manifestUrlsToKeep);
+
+ @Query(
+ "SELECT model.* FROM model LEFT JOIN model_view "
+ + "ON model.model_url = model_view.model_url "
+ + "WHERE model_view.model_url IS NULL")
+ abstract List<Model> queryUnusedModels();
+
+ // Insertion
+ @Insert(onConflict = OnConflictStrategy.REPLACE)
+ abstract void insert(Model model);
+
+ @Insert(onConflict = OnConflictStrategy.REPLACE)
+ abstract void insert(Manifest manifest);
+
+ @Insert(onConflict = OnConflictStrategy.REPLACE)
+ abstract void insert(ManifestModelCrossRef manifestModelCrossRef);
+
+ @Insert(onConflict = OnConflictStrategy.REPLACE)
+ abstract void insert(ManifestEnrollment manifestEnrollment);
+
+ @Transaction
+ void insertManifestAndModelCrossRef(String manifestUrl, String modelUrl) {
+ insert(Manifest.create(manifestUrl, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+ insert(ManifestModelCrossRef.create(manifestUrl, modelUrl));
+ }
+
+ @Transaction
+ void increaseManifestFailureCounts(String manifestUrl) {
+ List<Manifest> manifests = queryManifestWithManifestUrl(manifestUrl);
+ if (manifests.isEmpty()) {
+ insert(Manifest.create(manifestUrl, Manifest.STATUS_FAILED, /* failureCounts= */ 1));
+ } else {
+ Manifest prevManifest = Iterables.getOnlyElement(manifests);
+ insert(
+ Manifest.create(
+ manifestUrl, Manifest.STATUS_FAILED, prevManifest.getFailureCounts() + 1));
+ }
+ }
+
+ // Deletion
+ @Delete
+ abstract void deleteModels(List<Model> models);
+
+ @Delete
+ abstract void deleteManifests(List<Manifest> manifests);
+
+ @Delete
+ abstract void deleteManifestModelCrossRefs(List<ManifestModelCrossRef> manifestModelCrossRefs);
+
+ @Delete
+ abstract void deleteManifestEnrollments(List<ManifestEnrollment> manifestEnrollments);
+
+ @Transaction
+ void deleteUnusedManifestsAndModels() {
+ // Because Manifest table is the parent table of ManifestModelCrossRef table, related cross
+ // ref row in that table will be deleted automatically
+ deleteManifests(queryUnusedManifests());
+ deleteModels(queryUnusedModels());
+ }
+
+ @Transaction
+ void deleteUnusedManifestFailureRecords(List<String> manifestUrlsToKeep) {
+ deleteManifests(queryUnusedManifestFailureRecords(manifestUrlsToKeep));
+ }
+ }
+
+ abstract DownloadedModelDatabaseDao dao();
+
+ /** Dump the database for debugging. */
+ void dump(IndentingPrintWriter printWriter, ExecutorService executorService) {
+ printWriter.println("DownloadedModelDatabase");
+ printWriter.increaseIndent();
+ try {
+ // Each table is printed only after its query has finished, so a failing query cannot leave
+ // the writer's indentation unbalanced.
+ dumpTable(
+ printWriter, "Model Table:", executorService.submit(() -> dao().queryAllModels()).get());
+ dumpTable(
+ printWriter,
+ "Manifest Table:",
+ executorService.submit(() -> dao().queryAllManifests()).get());
+ dumpTable(
+ printWriter,
+ "ManifestModelCrossRef Table:",
+ executorService.submit(() -> dao().queryAllManifestModelCrossRefs()).get());
+ dumpTable(
+ printWriter,
+ "ManifestEnrollment Table:",
+ executorService.submit(() -> dao().queryAllManifestEnrollments()).get());
+ } catch (InterruptedException e) {
+ // Restore the interrupt flag so that the calling thread can still observe the interruption.
+ Thread.currentThread().interrupt();
+ TcLog.e(TAG, "Failed to dump the database", e);
+ } catch (ExecutionException e) {
+ TcLog.e(TAG, "Failed to dump the database", e);
+ }
+ printWriter.decreaseIndent();
+ }
+
+ /** Prints {@code title} followed by one indented line per element of {@code rows}. */
+ private static void dumpTable(IndentingPrintWriter printWriter, String title, List<?> rows) {
+ printWriter.println(title);
+ printWriter.increaseIndent();
+ for (Object row : rows) {
+ printWriter.println(row.toString());
+ }
+ printWriter.decreaseIndent();
+ }
+}
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManager.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManager.java
new file mode 100644
index 0000000..f6400fb
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManager.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import android.util.Pair;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import java.io.File;
+import java.util.List;
+import java.util.Map;
+import javax.annotation.Nullable;
+
+// TODO(licha): Let Worker access DB class directly, then we can make this a lister interface
+/** An interface to provide easy access to DownloadedModelDatabase. */
+interface DownloadedModelManager {
+
+ /** Returns the directory containing models downloaded by the downloader. */
+ File getModelDownloaderDir();
+
+ /**
+ * Returns all downloaded model files for the given model type.
+ *
+ * <p>This method should return quickly as it may be on the critical path of serving requests.
+ *
+ * <p>NOTE(review): the method is annotated {@code @Nullable} although the documented result for
+ * "no suitable model" is an empty list; confirm whether null is ever meant to be returned.
+ *
+ * @param modelType the type of the model
+ * @return the model files; empty if no suitable model is found
+ */
+ @Nullable
+ List<File> listModels(@ModelTypeDef String modelType);
+
+ /**
+ * Returns the model entry if the model represented by the url is in our database.
+ *
+ * @param modelUrl the model url
+ * @return the model entry from the internal database, or null if it does not exist
+ */
+ @Nullable
+ Model getModel(String modelUrl);
+
+ /**
+ * Returns the manifest entry if the manifest represented by the url is in our database.
+ *
+ * @param manifestUrl the manifest url
+ * @return the manifest entry from the internal database, or null if it does not exist
+ */
+ @Nullable
+ Manifest getManifest(String manifestUrl);
+
+ /**
+ * Returns the manifest enrollment entry if a manifest is registered for the given type and
+ * locale.
+ *
+ * @param modelType the model type of the enrollment
+ * @param localeTag the locale tag of the enrollment
+ * @return the manifest enrollment entry from the internal database, or null if it does not exist
+ */
+ @Nullable
+ ManifestEnrollment getManifestEnrollment(@ModelTypeDef String modelType, String localeTag);
+
+ /**
+ * Adds a newly downloaded model to the internal database.
+ *
+ * <p>The model must be linked to a manifest via {@link #registerManifest}. Otherwise it will be
+ * cleaned up automatically later.
+ *
+ * @param modelUrl the url where we downloaded the model from
+ * @param modelPath the path where we store the downloaded model
+ */
+ void registerModel(String modelUrl, String modelPath);
+
+ /**
+ * Adds a newly downloaded manifest to the internal database.
+ *
+ * <p>The manifest must be linked to a specific use case via {@link #registerManifestEnrollment}.
+ * Otherwise it will be cleaned up automatically later. Currently there is only one model in one
+ * manifest.
+ *
+ * @param manifestUrl the url where we downloaded the manifest from
+ * @param modelUrl the url where we downloaded the only model inside the manifest from
+ */
+ void registerManifest(String manifestUrl, String modelUrl);
+
+ /**
+ * Adds a failure record for the given manifest url.
+ *
+ * <p>If the manifest failed before, then increase the prevFailureCounts by one. We skip a
+ * manifest if it failed too many times before.
+ *
+ * @param manifestUrl the failed manifest url
+ */
+ void registerManifestDownloadFailure(String manifestUrl);
+
+ /**
+ * Links a manifest to a specific (modelType, localeTag) use case.
+ *
+ * <p>After this registration, we will start to use this model file for all requests for the given
+ * locale and the specified model type.
+ *
+ * @param modelType the model type
+ * @param localeTag the tag of the locale on user's device that this manifest should be used for
+ * @param manifestUrl the url of the manifest
+ */
+ void registerManifestEnrollment(
+ @ModelTypeDef String modelType, String localeTag, String manifestUrl);
+
+ /**
+ * Cleans up unused downloaded models and updates other internal states.
+ *
+ * @param bestLocaleTagAndManifestUrls a map of {@code modelType -> (localeTag, manifestUrl)}
+ * describing the manifests the worker tried to download
+ */
+ void onDownloadCompleted(Map<String, Pair<String, String>> bestLocaleTagAndManifestUrls);
+
+ /**
+ * Dumps the internal state for debugging.
+ *
+ * @param printWriter writer to write dumped states
+ */
+ void dump(IndentingPrintWriter printWriter);
+}
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
new file mode 100644
index 0000000..78559fd
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
@@ -0,0 +1,279 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import android.content.Context;
+import android.util.ArrayMap;
+import android.util.Pair;
+import androidx.annotation.GuardedBy;
+import androidx.room.Room;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierServiceExecutors;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ModelView;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+
+/**
+ * A singleton implementation of DownloadedModelManager.
+ *
+ * <p>Records about models, manifests and enrollments are stored in a {@link
+ * DownloadedModelDatabase}. An in-memory cache keyed by model type, guarded by {@code cacheLock},
+ * serves {@link #listModels} and is rebuilt by {@link #onDownloadCompleted}.
+ */
+public final class DownloadedModelManagerImpl implements DownloadedModelManager {
+  private static final String TAG = "DownloadedModelManagerImpl";
+  private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models";
+  private static final String DOWNLOADED_MODEL_DATABASE_NAME = "tcs-downloaded-model-db";
+
+  private static final Object staticLock = new Object();
+
+  @GuardedBy("staticLock")
+  private static DownloadedModelManagerImpl instance;
+
+  private final File modelDownloaderDir;
+  private final DownloadedModelDatabase db;
+  private final TextClassifierSettings settings;
+
+  private final Object cacheLock = new Object();
+
+  // modeltype -> downloaded model files
+  @GuardedBy("cacheLock")
+  private final ArrayMap<String, List<Model>> modelLookupCache;
+
+  @GuardedBy("cacheLock")
+  private boolean cacheInitialized;
+
+  /**
+   * Returns the process-wide instance, creating the database and the instance on first use.
+   *
+   * @param context context used to build the database and to locate the download directory
+   */
+  @Nullable
+  public static DownloadedModelManager getInstance(Context context) {
+    synchronized (staticLock) {
+      if (instance == null) {
+        DownloadedModelDatabase db =
+            Room.databaseBuilder(
+                    context, DownloadedModelDatabase.class, DOWNLOADED_MODEL_DATABASE_NAME)
+                .build();
+        File modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+        instance =
+            new DownloadedModelManagerImpl(db, modelDownloaderDir, new TextClassifierSettings());
+      }
+      return instance;
+    }
+  }
+
+  /** Creates a new, non-singleton instance backed by the given dependencies. */
+  @VisibleForTesting
+  static DownloadedModelManagerImpl getInstanceForTesting(
+      DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) {
+    return new DownloadedModelManagerImpl(db, modelDownloaderDir, settings);
+  }
+
+  private DownloadedModelManagerImpl(
+      DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) {
+    this.db = db;
+    this.modelDownloaderDir = modelDownloaderDir;
+    this.modelLookupCache = new ArrayMap<>();
+    // Pre-populate one (empty) bucket per known model type.
+    for (String modelType : ModelType.values()) {
+      this.modelLookupCache.put(modelType, new ArrayList<>());
+    }
+    this.settings = settings;
+    this.cacheInitialized = false;
+  }
+
+  /**
+   * Returns the directory for downloaded models, creating it if it does not exist yet.
+   *
+   * <p>A failure to create the directory is logged; the (missing) directory is still returned.
+   */
+  @Override
+  public File getModelDownloaderDir() {
+    // mkdirs() also returns false when the directory was created concurrently, so only report a
+    // failure if the directory is still missing afterwards.
+    if (!modelDownloaderDir.exists()
+        && !modelDownloaderDir.mkdirs()
+        && !modelDownloaderDir.exists()) {
+      TcLog.e(TAG, "Failed to create model downloader dir: " + modelDownloaderDir);
+    }
+    return modelDownloaderDir;
+  }
+
+  /**
+   * Returns the files of all cached models of the given type whose url is not blocklisted.
+   *
+   * <p>The cache is built from the database on the first call.
+   *
+   * @param modelType the type of the model
+   * @return the model files; empty if there is none or if {@code modelType} is not a known type
+   */
+  @Override
+  @Nullable
+  public ImmutableList<File> listModels(@ModelTypeDef String modelType) {
+    synchronized (cacheLock) {
+      if (!cacheInitialized) {
+        updateCache();
+      }
+      List<Model> cachedModels = modelLookupCache.get(modelType);
+      if (cachedModels == null) {
+        // The cache only has buckets for ModelType.values(); anything else has no models.
+        TcLog.w(TAG, "Unknown model type: " + modelType);
+        return ImmutableList.of();
+      }
+      ImmutableList.Builder<File> builder = ImmutableList.builder();
+      ImmutableList<String> blockedModels = settings.getModelUrlBlocklist();
+      for (Model model : cachedModels) {
+        if (blockedModels.contains(model.getModelUrl())) {
+          TcLog.d(TAG, "Model is blocklisted: " + model);
+          continue;
+        }
+        builder.add(new File(model.getModelPath()));
+      }
+      return builder.build();
+    }
+  }
+
+  /** Returns the first model entry stored for {@code modelUrl}, or null if there is none. */
+  @Override
+  @Nullable
+  public Model getModel(String modelUrl) {
+    List<Model> models = db.dao().queryModelWithModelUrl(modelUrl);
+    return Iterables.getFirst(models, null);
+  }
+
+  /** Returns the first manifest entry stored for {@code manifestUrl}, or null if there is none. */
+  @Override
+  @Nullable
+  public Manifest getManifest(String manifestUrl) {
+    List<Manifest> manifests = db.dao().queryManifestWithManifestUrl(manifestUrl);
+    return Iterables.getFirst(manifests, null);
+  }
+
+  /**
+   * Returns the first enrollment stored for the given model type and locale tag, or null if there
+   * is none.
+   */
+  @Override
+  @Nullable
+  public ManifestEnrollment getManifestEnrollment(
+      @ModelTypeDef String modelType, String localeTag) {
+    List<ManifestEnrollment> manifestEnrollments =
+        db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(modelType, localeTag);
+    return Iterables.getFirst(manifestEnrollments, null);
+  }
+
+  /** Inserts a model record for the given url and local path into the database. */
+  @Override
+  public void registerModel(String modelUrl, String modelPath) {
+    db.dao().insert(Model.create(modelUrl, modelPath));
+  }
+
+  /** Inserts the manifest together with its link to the given model into the database. */
+  @Override
+  public void registerManifest(String manifestUrl, String modelUrl) {
+    db.dao().insertManifestAndModelCrossRef(manifestUrl, modelUrl);
+  }
+
+  /** Increases the failure count recorded for the given manifest url. */
+  @Override
+  public void registerManifestDownloadFailure(String manifestUrl) {
+    db.dao().increaseManifestFailureCounts(manifestUrl);
+  }
+
+  /** Inserts an enrollment linking the manifest to the given model type and locale tag. */
+  @Override
+  public void registerManifestEnrollment(
+      @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+    db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl));
+  }
+
+  /** Dumps the database content followed by the in-memory lookup cache. */
+  @Override
+  public void dump(IndentingPrintWriter printWriter) {
+    printWriter.println("DownloadedModelManagerImpl:");
+    printWriter.increaseIndent();
+    db.dump(printWriter, TextClassifierServiceExecutors.getDownloaderExecutor());
+    printWriter.println("ModelLookupCache:");
+    synchronized (cacheLock) {
+      for (Map.Entry<String, List<Model>> entry : modelLookupCache.entrySet()) {
+        printWriter.println(entry.getKey());
+        printWriter.increaseIndent();
+        for (Model model : entry.getValue()) {
+          printWriter.println(model.toString());
+        }
+        printWriter.decreaseIndent();
+      }
+    }
+    printWriter.decreaseIndent();
+  }
+
+  /**
+   * Cleans up obsolete database records and model files, then rebuilds the lookup cache.
+   *
+   * @param modelTypeToLocaleTagAndManifestUrls a map of {@code modelType -> (localeTag,
+   *     manifestUrl)} describing the manifests the worker tried to download
+   */
+  @Override
+  public void onDownloadCompleted(
+      Map<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrls) {
+    TcLog.v(TAG, "Start to clean up models and update model lookup cache...");
+    // Step 1: Clean up ManifestEnrollment table
+    db.dao()
+        .deleteManifestEnrollments(
+            findManifestEnrollmentsToDelete(modelTypeToLocaleTagAndManifestUrls));
+    // Step 2: Clean up Manifests and Models that are not linked to any ManifestEnrollment
+    db.dao().deleteUnusedManifestsAndModels();
+    // Step 3: Clean up Manifest failure records
+    // We only keep a failure record if the worker still tries to download it
+    // We restrict the deletion to failure records only because although some manifest urls are not
+    // in allAttemptedManifestUrls, they can still be useful (e.g. current manifest is v901, and we
+    // failed to download v902. v901 will not be in the map, but it should be kept.)
+    List<String> allAttemptedManifestUrls =
+        modelTypeToLocaleTagAndManifestUrls.values().stream()
+            .map(localTagAndManifestUrlPair -> localTagAndManifestUrlPair.second)
+            .collect(Collectors.toList());
+    db.dao().deleteUnusedManifestFailureRecords(allAttemptedManifestUrls);
+    // Step 4: Update lookup cache
+    updateCache();
+    // Step 5: Clean up unused model files.
+    deleteUnusedModelFiles();
+  }
+
+  /**
+   * Returns the enrollments in the database that are superseded by the desired manifests.
+   *
+   * <p>For a model type without a desired manifest, all enrollments are returned. For a model type
+   * whose desired manifest is enrolled, all other enrollments of that type are returned. If the
+   * desired manifest is not enrolled, nothing of that type is returned.
+   */
+  private List<ManifestEnrollment> findManifestEnrollmentsToDelete(
+      Map<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrls) {
+    List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments();
+    List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>();
+    for (String modelType : ModelType.values()) {
+      List<ManifestEnrollment> manifestEnrollmentsByType =
+          allManifestEnrollments.stream()
+              .filter(modelEnrollment -> modelEnrollment.getModelType().equals(modelType))
+              .collect(Collectors.toList());
+      Pair<String, String> localeTagAndManifestUrl =
+          modelTypeToLocaleTagAndManifestUrls.get(modelType);
+      if (localeTagAndManifestUrl == null) {
+        // No suitable manifest configured for this model type. Delete everything.
+        manifestEnrollmentsToDelete.addAll(manifestEnrollmentsByType);
+        continue;
+      }
+      String localeTag = localeTagAndManifestUrl.first;
+      String manifestUrl = localeTagAndManifestUrl.second;
+      Optional<ManifestEnrollment> optionalManifestEnrollment =
+          manifestEnrollmentsByType.stream()
+              .filter(
+                  manifestEnrollment ->
+                      manifestEnrollment.getLocaleTag().equals(localeTag)
+                          && manifestEnrollment.getManifestUrl().equals(manifestUrl))
+              .findAny();
+      if (optionalManifestEnrollment.isPresent()) {
+        // The desired manifest is downloaded successfully. Delete everything else.
+        manifestEnrollmentsToDelete.addAll(manifestEnrollmentsByType);
+        manifestEnrollmentsToDelete.remove(optionalManifestEnrollment.get());
+      } else {
+        // The desired manifest failed to be downloaded. Do not delete anything.
+        TcLog.w(
+            TAG,
+            String.format(
+                "Desired manifest is missing on download completed: %s, %s, %s",
+                modelType, localeTag, manifestUrl));
+      }
+    }
+    return manifestEnrollmentsToDelete;
+  }
+
+  /** Deletes every file in the download directory that no model record in the database uses. */
+  private void deleteUnusedModelFiles() {
+    Set<String> modelPathsToKeep =
+        db.dao().queryAllModels().stream().map(Model::getModelPath).collect(Collectors.toSet());
+    File[] modelFiles = getModelDownloaderDir().listFiles();
+    if (modelFiles == null) {
+      // listFiles() returns null when the directory is missing or cannot be read.
+      TcLog.w(TAG, "Failed to list files in: " + modelDownloaderDir.getAbsolutePath());
+      return;
+    }
+    for (File modelFile : modelFiles) {
+      if (!modelPathsToKeep.contains(modelFile.getAbsolutePath())) {
+        TcLog.d(TAG, "Delete model file: " + modelFile.getAbsolutePath());
+        if (!modelFile.delete()) {
+          TcLog.e(TAG, "Failed to delete model file: " + modelFile.getAbsolutePath());
+        }
+      }
+    }
+  }
+
+  // Clear the cache table and rebuild the cache based on ModelView table
+  private void updateCache() {
+    synchronized (cacheLock) {
+      TcLog.v(TAG, "Updating model lookup cache...");
+      for (String modelType : ModelType.values()) {
+        modelLookupCache.get(modelType).clear();
+      }
+      for (ModelView modelView : db.dao().queryAllModelViews()) {
+        String modelType = modelView.getManifestEnrollment().getModelType();
+        List<Model> models = modelLookupCache.get(modelType);
+        if (models == null) {
+          // The database may hold a type that is not in ModelType.values(); it cannot be served.
+          TcLog.w(TAG, "Skip model with unknown model type: " + modelType);
+          continue;
+        }
+        models.add(modelView.getModel());
+      }
+      cacheInitialized = true;
+    }
+  }
+}
diff --git a/java/src/com/android/textclassifier/downloader/LocaleUtils.java b/java/src/com/android/textclassifier/downloader/LocaleUtils.java
new file mode 100644
index 0000000..a6791a3
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/LocaleUtils.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import android.text.TextUtils;
+import android.util.Pair;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.google.common.annotations.VisibleForTesting;
+import java.util.Collection;
+import java.util.List;
+import java.util.Locale;
+import javax.annotation.Nullable;
+
+/** Helpers for matching a locale against the locale tags configured for model manifests. */
+final class LocaleUtils {
+  @VisibleForTesting static final String UNIVERSAL_LOCALE_TAG = "universal";
+
+  private LocaleUtils() {}
+
+  /**
+   * Finds the best locale tag as well as the configured manifest url from device config.
+   *
+   * @param modelType the model type
+   * @param targetLocale target locale
+   * @param settings TextClassifierSettings to check device config
+   * @return a pair of (bestLocaleTag, manifestUrl), or null if no tag matches or no url is
+   *     configured for the matched tag
+   */
+  @Nullable
+  static Pair<String, String> lookupBestLocaleTagAndManifestUrl(
+      @ModelTypeDef String modelType, Locale targetLocale, TextClassifierSettings settings) {
+    String localeTag =
+        lookupBestLocaleTag(targetLocale, settings.getLanguageTagsForManifestURL(modelType));
+    if (localeTag != null) {
+      String url = settings.getManifestURL(modelType, localeTag);
+      if (!TextUtils.isEmpty(url)) {
+        return Pair.create(localeTag, url);
+      }
+    }
+    return null;
+  }
+
+  /**
+   * Finds the best locale tag for the target locale.
+   *
+   * <p>Falls back to {@link #UNIVERSAL_LOCALE_TAG} when it is available and nothing else matches.
+   *
+   * @return the best tag, or null if no tag is suitable
+   */
+  @Nullable
+  static String lookupBestLocaleTag(Locale targetLocale, Collection<String> availableTags) {
+    // Notice: this lookup API just tries to match the longest prefix for the target locale tag.
+    // Its implementation looks inefficient and the behavior may not be 100% desired. E.g. if the
+    // target locale is en, and we only have en-uk in available tags, the current API returns null.
+    List<Locale.LanguageRange> targetRanges =
+        Locale.LanguageRange.parse(targetLocale.toLanguageTag());
+    String matchedTag = Locale.lookupTag(targetRanges, availableTags);
+    if (matchedTag == null && availableTags.contains(UNIVERSAL_LOCALE_TAG)) {
+      return UNIVERSAL_LOCALE_TAG;
+    }
+    return matchedTag;
+  }
+}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
index 2629c20..78d26fa 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
@@ -22,50 +22,39 @@
import android.content.Context;
import android.content.Intent;
import android.content.IntentFilter;
-import android.os.LocaleList;
import android.provider.DeviceConfig;
-import android.text.TextUtils;
import androidx.work.BackoffPolicy;
import androidx.work.Constraints;
import androidx.work.ExistingWorkPolicy;
import androidx.work.ListenableWorker;
import androidx.work.NetworkType;
import androidx.work.OneTimeWorkRequest;
-import androidx.work.WorkInfo;
+import androidx.work.Operation;
import androidx.work.WorkManager;
-import androidx.work.WorkQuery;
-import com.android.textclassifier.common.ModelFileManager;
-import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Enums;
import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
-import java.time.Duration;
+import java.io.File;
import java.util.List;
-import java.util.Locale;
-import java.util.concurrent.ExecutionException;
+import javax.annotation.Nullable;
/** Manager to listen to config update and download latest models. */
public final class ModelDownloadManager {
private static final String TAG = "ModelDownloadManager";
- static final String UNIVERSAL_LOCALE_TAG = "universal";
-
- public static final int CHECK_REASON_TCS_ON_CREATE = 0;
- public static final int CHECK_REASON_DEVICE_CONFIG_UPDATED = 1;
- public static final int CHECK_REASON_LOCALE_UPDATED = 2;
-
- // Keep the records forever (100 Years). We use the record to skip previously failed tasks.
- static final long DAYS_TO_KEEP_THE_DOWNLOAD_RESULT = 365000L;
-
- private final Object lock = new Object();
+ @VisibleForTesting static final String UNIQUE_QUEUE_NAME = "ModelDownloadWorkManagerQueue";
private final Context appContext;
private final Class<? extends ListenableWorker> modelDownloadWorkerClass;
- private final ModelFileManager modelFileManager;
+ private final DownloadedModelManager downloadedModelManager;
private final TextClassifierSettings settings;
private final ListeningExecutorService executorService;
private final DeviceConfig.OnPropertiesChangedListener deviceConfigListener;
@@ -75,28 +64,31 @@
* Constructor for ModelDownloadManager.
*
* @param appContext the context of this application
- * @param modelFileManager ModelFileManager to interact with storage layer
* @param settings TextClassifierSettings to access DeviceConfig and other settings
* @param executorService background executor service
*/
public ModelDownloadManager(
Context appContext,
- ModelFileManager modelFileManager,
TextClassifierSettings settings,
ListeningExecutorService executorService) {
- this(appContext, NewModelDownloadWorker.class, modelFileManager, settings, executorService);
+ this(
+ appContext,
+ NewModelDownloadWorker.class,
+ DownloadedModelManagerImpl.getInstance(appContext),
+ settings,
+ executorService);
}
@VisibleForTesting
ModelDownloadManager(
Context appContext,
Class<? extends ListenableWorker> modelDownloadWorkerClass,
- ModelFileManager modelFileManager,
+ DownloadedModelManager downloadedModelManager,
TextClassifierSettings settings,
ListeningExecutorService executorService) {
this.appContext = Preconditions.checkNotNull(appContext);
this.modelDownloadWorkerClass = Preconditions.checkNotNull(modelDownloadWorkerClass);
- this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
+ this.downloadedModelManager = Preconditions.checkNotNull(downloadedModelManager);
this.settings = Preconditions.checkNotNull(settings);
this.executorService = Preconditions.checkNotNull(executorService);
@@ -104,8 +96,6 @@
new DeviceConfig.OnPropertiesChangedListener() {
@Override
public void onPropertiesChanged(DeviceConfig.Properties unused) {
- // We will only be notified for changes in our package. Trigger the check even when the
- // change is unrelated just in case we missed a previous update.
onTextClassifierDeviceConfigChanged();
}
};
@@ -118,6 +108,12 @@
};
}
+ /** Returns the downloaded models for the given modelType. */
+ @Nullable
+ public List<File> listDownloadedModels(@ModelTypeDef String modelType) {
+ return downloadedModelManager.listModels(modelType);
+ }
+
/** Notifies the model downlaoder that the text classifier service is created. */
public void onTextClassifierServiceCreated() {
DeviceConfig.addOnPropertiesChangedListener(
@@ -128,12 +124,8 @@
if (!settings.isModelDownloadManagerEnabled()) {
return;
}
- executorService.execute(
- () -> {
- TcLog.d(TAG, "Checking downloader flags because of the start of TextClassifierService.");
- modelFileManager.deleteUnusedModelFiles();
- checkConfigAndScheduleDownloads();
- });
+ TcLog.v(TAG, "Try to schedule model download work because TextClassifierService started.");
+ scheduleDownloadWork();
}
// TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
@@ -143,12 +135,8 @@
if (!settings.isModelDownloadManagerEnabled()) {
return;
}
- executorService.execute(
- () -> {
- TcLog.d(TAG, "Checking downloader flags because of locale changes.");
- modelFileManager.deleteUnusedModelFiles();
- checkConfigAndScheduleDownloads();
- });
+ TcLog.v(TAG, "Try to schedule model download work because of system locale changes.");
+ scheduleDownloadWork();
}
// TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
@@ -158,11 +146,8 @@
if (!settings.isModelDownloadManagerEnabled()) {
return;
}
- executorService.execute(
- () -> {
- TcLog.d(TAG, "Checking downloader flags because of device config changes.");
- checkConfigAndScheduleDownloads();
- });
+ TcLog.v(TAG, "Try to schedule model download work because of device config changes.");
+ scheduleDownloadWork();
}
/** Clean up internal states on destroying. */
@@ -173,96 +158,58 @@
}
/**
- * Check DeviceConfig and schedule new model download requests synchronously. This method is
- * synchronized and contains blocking operations, only call it in a background executor.
+ * Dumps the internal state for debugging.
+ *
+ * @param printWriter writer to write dumped states
*/
- private void checkConfigAndScheduleDownloads() {
- synchronized (lock) {
- WorkManager workManager = WorkManager.getInstance(appContext);
- List<Locale.LanguageRange> languageRanges =
- Locale.LanguageRange.parse(LocaleList.getAdjustedDefault().toLanguageTags());
- for (String modelType : ModelType.values()) {
- // Notice: Be careful of the Locale.lookupTag() matching logic: 1) it will convert the tag
- // to lower-case only; 2) it matches tags from tail to head and does not allow missing
- // pieces. E.g. if your system locale is zh-hans-cn, it won't match zh-cn.
- String bestTag =
- Locale.lookupTag(languageRanges, settings.getLanguageTagsForManifestURL(modelType));
- String localeTag = bestTag != null ? bestTag : UNIVERSAL_LOCALE_TAG;
- TcLog.v(TAG, String.format("Checking model type: %s, best tag: %s", modelType, localeTag));
-
- // One manifest url can uniquely identify a model in the world
- String manifestUrl = settings.getManifestURL(modelType, localeTag);
- if (TextUtils.isEmpty(manifestUrl)) {
- continue;
- }
-
- // Query the history for this url. Stop if we handled this url before.
- // TODO(licha): We may skip downloads incorrectly if we switch locales back and forth.
- WorkQuery workQuery = WorkQuery.Builder.fromTags(ImmutableList.of(manifestUrl)).build();
- try {
- List<WorkInfo> workInfos = workManager.getWorkInfos(workQuery).get();
- if (!workInfos.isEmpty()) {
- TcLog.v(
- TAG,
- String.format(
- "Target manifest has an existing state of: %s. Skip.",
- workInfos.get(0).getState().name()));
- continue;
- }
- } catch (ExecutionException | InterruptedException e) {
- TcLog.e(TAG, "Failed to query queued requests. Ignore and continue.", e);
- }
-
- NetworkType networkType =
- Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType())
- .or(NetworkType.UNMETERED);
- OneTimeWorkRequest downloadRequest =
- new OneTimeWorkRequest.Builder(modelDownloadWorkerClass)
- .setInputData(
- NewModelDownloadWorker.createInputData(
- modelType, localeTag, manifestUrl, settings.getModelDownloadMaxAttempts()))
- .addTag(manifestUrl)
- .setConstraints(
- new Constraints.Builder()
- .setRequiredNetworkType(networkType)
- .setRequiresBatteryNotLow(true)
- .setRequiresStorageNotLow(true)
- .build())
- .setBackoffCriteria(
- BackoffPolicy.EXPONENTIAL,
- settings.getModelDownloadBackoffDelayInMillis(),
- MILLISECONDS)
- .keepResultsForAtLeast(Duration.ofDays(DAYS_TO_KEEP_THE_DOWNLOAD_RESULT))
- .build();
-
- // When we enqueue a new request, existing pending request in the same queue will be
- // cancelled. With this, device will be able to abort previous unfinished downloads
- // (e.g. 711) when a fresher model is already(e.g. v712).
- try {
- // Block until we enqueue the request successfully (to avoid potential race condition)
- workManager
- .enqueueUniqueWork(
- formatUniqueWorkName(modelType, localeTag),
- ExistingWorkPolicy.REPLACE,
- downloadRequest)
- .getResult()
- .get();
- TextClassifierDownloadLogger.downloadSceduled(modelType, manifestUrl);
- TcLog.d(TAG, "Download scheduled: " + manifestUrl);
- } catch (ExecutionException | InterruptedException e) {
- TextClassifierDownloadLogger.downloadFailedAndAbort(
- modelType,
- manifestUrl,
- ModelDownloadException.FAILED_TO_SCHEDULE,
- /* runAttemptCount= */ 0);
- TcLog.e(TAG, "Failed to enqueue a request", e);
- }
- }
- }
+ public void dump(IndentingPrintWriter printWriter) {
+ printWriter.println("ModelDownloadManager:");
+ // Indent the DownloadedModelManager state so that it is nested under the header above.
+ printWriter.increaseIndent();
+ downloadedModelManager.dump(printWriter);
+ printWriter.decreaseIndent();
+ }
- /** Formats unique work name for WorkManager. */
- static String formatUniqueWorkName(@ModelType.ModelTypeDef String modelType, String localeTag) {
- return String.format("%s-%s", modelType, localeTag);
+  /**
+   * Enqueues an idempotent work item that checks device configs and downloads model files if
+   * necessary.
+   *
+   * <p>The work is enqueued as unique work under {@code UNIQUE_QUEUE_NAME} with the {@code
+   * APPEND_OR_REPLACE} policy, so at most one work item runs at any time: if one is already
+   * pending or running, the new one is appended as a child of it. The outcome of the enqueue
+   * operation is only logged.
+   */
+  private void scheduleDownloadWork() {
+    NetworkType requiredNetworkType =
+        Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType())
+            .or(NetworkType.UNMETERED);
+    OneTimeWorkRequest.Builder workBuilder =
+        new OneTimeWorkRequest.Builder(modelDownloadWorkerClass);
+    Constraints workConstraints =
+        new Constraints.Builder()
+            .setRequiredNetworkType(requiredNetworkType)
+            .setRequiresBatteryNotLow(true)
+            .setRequiresStorageNotLow(true)
+            .build();
+    workBuilder.setConstraints(workConstraints);
+    workBuilder.setBackoffCriteria(
+        BackoffPolicy.EXPONENTIAL, settings.getModelDownloadBackoffDelayInMillis(), MILLISECONDS);
+    OneTimeWorkRequest downloadRequest = workBuilder.build();
+
+    Operation enqueueOperation =
+        WorkManager.getInstance(appContext)
+            .enqueueUniqueWork(
+                UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, downloadRequest);
+    FutureCallback<Operation.State.SUCCESS> logEnqueueResult =
+        new FutureCallback<Operation.State.SUCCESS>() {
+          @Override
+          public void onSuccess(Operation.State.SUCCESS unused) {
+            TcLog.v(TAG, "Download work scheduled.");
+          }
+
+          @Override
+          public void onFailure(Throwable t) {
+            TcLog.e(TAG, "Failed to schedule download work: ", t);
+          }
+        };
+    Futures.addCallback(enqueueOperation.getResult(), logEnqueueResult, executorService);
+  }
}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloader.java b/java/src/com/android/textclassifier/downloader/ModelDownloader.java
index 7c0c32d..7e22d99 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloader.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloader.java
@@ -35,12 +35,13 @@
/**
* Downloads a model file and validate it based on given model info.
*
- * <p>The file should be in the cache folder. Returns the File if succeed. If the download or
+ * <p>The file should be in the target folder. Returns the File if succeed. If the download or
* validation fails, the unfinished model file should be cleaned up. Failures should be wrapped
* inside a {@link ModelDownloadException} and throw.
*
+ * @param targetDir the target directory for the downloaded model
* @param modelInfo the model information in manifest used for downloading and validation
* @return the downloaded model file
*/
- ListenableFuture<File> downloadModel(ModelManifest.Model modelInfo);
+ ListenableFuture<File> downloadModel(File targetDir, ModelManifest.Model modelInfo);
}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
index 564cf67..366364f 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
@@ -84,9 +84,8 @@
}
@Override
- public ListenableFuture<File> downloadModel(ModelManifest.Model model) {
- File modelFile =
- new File(context.getCacheDir(), model.getUrl().replaceAll("[^A-Za-z0-9]", "_") + ".model");
+ public ListenableFuture<File> downloadModel(File targetDir, ModelManifest.Model model) {
+ File modelFile = new File(targetDir, model.getUrl().replaceAll("[^A-Za-z0-9]", "_") + ".model");
ListenableFuture<File> modelFileFuture =
Futures.transform(
download(URI.create(model.getUrl()), modelFile),
@@ -100,7 +99,7 @@
new FutureCallback<File>() {
@Override
public void onSuccess(File pendingModelFile) {
- TcLog.d(TAG, "Download mode successfully: " + pendingModelFile.getAbsolutePath());
+ TcLog.v(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath());
}
@Override
diff --git a/java/src/com/android/textclassifier/downloader/NewModelDownloadWorker.java b/java/src/com/android/textclassifier/downloader/NewModelDownloadWorker.java
index 0388165..436bfe4 100644
--- a/java/src/com/android/textclassifier/downloader/NewModelDownloadWorker.java
+++ b/java/src/com/android/textclassifier/downloader/NewModelDownloadWorker.java
@@ -17,167 +17,269 @@
package com.android.textclassifier.downloader;
import android.content.Context;
-import androidx.work.Data;
+import android.util.ArrayMap;
+import android.util.Pair;
import androidx.work.ListenableWorker;
import androidx.work.WorkerParameters;
-import com.android.textclassifier.common.ModelFileManager;
import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.ModelType.ModelTypeDef;
import com.android.textclassifier.common.TextClassifierServiceExecutors;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.FluentFuture;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
-import java.io.File;
-import java.nio.file.Files;
-import java.nio.file.StandardCopyOption;
-import java.util.concurrent.ExecutorService;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+import java.util.ArrayList;
+import java.util.Locale;
// TODO(licha): Rename this to ModelDownloadWorker.
-// TODO(licha): Consider whether we should let the worker handle all locales/model types.
/** The WorkManager worker to download models for TextClassifierService. */
public final class NewModelDownloadWorker extends ListenableWorker {
private static final String TAG = "NewModelDownloadWorker";
- private static final int MAX_DOWNLOAD_ATTEMPTS_DEFAULT = 5;
- static final String DATA_MODEL_TYPE_KEY = "NewDownloadWorker_modelType";
- static final String DATA_LOCALE_TAG_KEY = "NewDownloadWorker_localeTag";
- static final String DATA_MANIFEST_URL_KEY = "NewDownloadWorker_manifestUrl";
- static final String DATA_MAX_DOWNLOAD_ATTEMPTS_KEY = "NewDownloadWorker_maxDownloadAttempts";
-
- @ModelTypeDef private final String modelType;
- private final String manifestUrl;
- private final int maxDownloadAttempts;
-
- private final ExecutorService executorService;
+ private final ListeningExecutorService executorService;
private final ModelDownloader downloader;
- private final File modelDownloaderDir;
- private final Runnable postDownloadCleanUpRunnable;
+ private final DownloadedModelManager downloadedModelManager;
+ private final TextClassifierSettings settings;
+
+ private final Object lock = new Object();
+
+ @GuardedBy("lock")
+ private final ArrayMap<String, ListenableFuture<Void>> pendingDownloads;
+
+ private ImmutableMap<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrls;
public NewModelDownloadWorker(Context context, WorkerParameters workerParams) {
super(context, workerParams);
-
- this.modelType = Preconditions.checkNotNull(getInputData().getString(DATA_MODEL_TYPE_KEY));
- this.manifestUrl = Preconditions.checkNotNull(getInputData().getString(DATA_MANIFEST_URL_KEY));
- this.maxDownloadAttempts =
- getInputData().getInt(DATA_MAX_DOWNLOAD_ATTEMPTS_KEY, MAX_DOWNLOAD_ATTEMPTS_DEFAULT);
-
this.executorService = TextClassifierServiceExecutors.getDownloaderExecutor();
this.downloader = new ModelDownloaderImpl(context, executorService);
- ModelFileManager modelFileManager = new ModelFileManager(context, new TextClassifierSettings());
- this.modelDownloaderDir = modelFileManager.getModelDownloaderDir();
- this.postDownloadCleanUpRunnable = modelFileManager::deleteUnusedModelFiles;
+ this.downloadedModelManager = DownloadedModelManagerImpl.getInstance(context);
+ this.settings = new TextClassifierSettings();
+ this.pendingDownloads = new ArrayMap<>();
+ this.modelTypeToLocaleTagAndManifestUrls = null;
}
@VisibleForTesting
NewModelDownloadWorker(
Context context,
WorkerParameters workerParams,
- ExecutorService executorService,
+ ListeningExecutorService executorService,
ModelDownloader modelDownloader,
- File modelDownloaderDir,
- Runnable postDownloadCleanUpRunnable) {
+ DownloadedModelManager downloadedModelManager,
+ TextClassifierSettings settings) {
super(context, workerParams);
-
- this.modelType = Preconditions.checkNotNull(getInputData().getString(DATA_MODEL_TYPE_KEY));
- this.manifestUrl = Preconditions.checkNotNull(getInputData().getString(DATA_MANIFEST_URL_KEY));
- this.maxDownloadAttempts =
- getInputData().getInt(DATA_MAX_DOWNLOAD_ATTEMPTS_KEY, MAX_DOWNLOAD_ATTEMPTS_DEFAULT);
-
this.executorService = executorService;
this.downloader = modelDownloader;
- this.modelDownloaderDir = modelDownloaderDir;
- this.postDownloadCleanUpRunnable = postDownloadCleanUpRunnable;
+ this.downloadedModelManager = downloadedModelManager;
+ this.settings = settings;
+ this.pendingDownloads = new ArrayMap<>();
+ this.modelTypeToLocaleTagAndManifestUrls = null;
}
@Override
public final ListenableFuture<ListenableWorker.Result> startWork() {
- if (getRunAttemptCount() >= maxDownloadAttempts) {
- TcLog.d(TAG, "Max attempt reached. Abort download task.");
- TextClassifierDownloadLogger.downloadFailedAndAbort(
- modelType,
- manifestUrl,
- // TODO(licha): Add a new failure reason for this
- ModelDownloadException.UNKNOWN_FAILURE_REASON,
- getRunAttemptCount());
+ // Notice: startWork() is invoked on the main thread
+ if (!settings.isModelDownloadManagerEnabled()) {
+ TcLog.e(TAG, "Model Downloader is disabled. Abort the work.");
return Futures.immediateFuture(ListenableWorker.Result.failure());
}
- FluentFuture<ListenableWorker.Result> resultFuture =
- FluentFuture.from(downloader.downloadManifest(manifestUrl))
- .transformAsync(
- manifest -> {
- // TODO(licha): put this in a function to improve the readability
- ModelManifest.Model modelInfo = manifest.getModels(0);
- File targetModelFile =
- new File(
- modelDownloaderDir,
- formatFileNameByModelTypeAndUrl(modelType, modelInfo.getUrl()));
- if (targetModelFile.exists()) {
- TcLog.d(
- TAG,
- "Target model file already exists. Skip download and reuse it: "
- + targetModelFile.getAbsolutePath());
- TextClassifierDownloadLogger.downloadSucceeded(
- modelType, manifestUrl, getRunAttemptCount());
- return Futures.immediateFuture(ListenableWorker.Result.success());
- } else {
- return Futures.transform(
- downloadAndMoveModel(targetModelFile, modelInfo),
- unused -> {
- TextClassifierDownloadLogger.downloadSucceeded(
- modelType, manifestUrl, getRunAttemptCount());
- return ListenableWorker.Result.success();
- },
- executorService);
- }
- },
- executorService)
- .catching(
- Throwable.class,
- e -> {
- TcLog.e(TAG, "Download attempt failed.", e);
- int errorCode =
- (e instanceof ModelDownloadException)
- ? ((ModelDownloadException) e).getErrorCode()
- : ModelDownloadException.UNKNOWN_FAILURE_REASON;
- // Retry until reach max allowed attempts (attempt starts from 0)
- // The backoff time between two tries will grow exponentially (i.e. 30s, 1min,
- // 2min, 4min). This is configurable when building the request.
- TextClassifierDownloadLogger.downloadFailedAndRetry(
- modelType, manifestUrl, errorCode, getRunAttemptCount());
- return ListenableWorker.Result.retry();
- },
- executorService);
- resultFuture.addListener(postDownloadCleanUpRunnable, executorService);
- return resultFuture;
+ TcLog.v(TAG, "Start download work...");
+ if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) {
+ TcLog.d(TAG, "Max attempt reached. Abort download work.");
+ return Futures.immediateFuture(ListenableWorker.Result.failure());
+ }
+
+ return FluentFuture.from(Futures.submitAsync(this::checkAndDownloadModels, executorService))
+ .transform(
+ allSucceeded -> {
+ Preconditions.checkNotNull(modelTypeToLocaleTagAndManifestUrls);
+ downloadedModelManager.onDownloadCompleted(modelTypeToLocaleTagAndManifestUrls);
+ TcLog.v(TAG, "Download work completed. Succeeded: " + allSucceeded);
+ return allSucceeded
+ ? ListenableWorker.Result.success()
+ : ListenableWorker.Result.retry();
+ },
+ executorService)
+ .catching(
+ Throwable.class,
+ t -> {
+ TcLog.e(TAG, "Unexpected Exception during downloading: ", t);
+ return ListenableWorker.Result.retry();
+ },
+ executorService);
}
- private ListenableFuture<Void> downloadAndMoveModel(
- File targetModelFile, ModelManifest.Model modelInfo) {
- return Futures.transform(
- downloader.downloadModel(modelInfo),
- pendingModelFile -> {
- try {
- if (!modelDownloaderDir.exists()) {
- modelDownloaderDir.mkdirs();
- }
- Files.move(
- pendingModelFile.toPath(),
- targetModelFile.toPath(),
- StandardCopyOption.ATOMIC_MOVE,
- StandardCopyOption.REPLACE_EXISTING);
- TcLog.d(
- TAG, "Model file downloaded successfully: " + targetModelFile.getAbsolutePath());
- return null;
- } catch (Exception e) {
- pendingModelFile.delete();
- throw new ModelDownloadException(ModelDownloadException.FAILED_TO_MOVE_MODEL, e);
- }
- },
- executorService);
+ /**
+ * Picks the best manifest for each model type and starts every download that is still needed.
+ *
+ * <p>The chosen locale tag and manifest URL per model type are stored in
+ * modelTypeToLocaleTagAndManifestUrls. Resolves to true if all started downloads succeeded.
+ */
+ private ListenableFuture<Boolean> checkAndDownloadModels() {
+ Locale primaryLocale = Locale.getDefault();
+ ArrayList<ListenableFuture<Boolean>> downloadResultFutures = new ArrayList<>();
+ ImmutableMap.Builder<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrlsBuilder =
+ ImmutableMap.builder();
+ for (String modelType : ModelType.values()) {
+ Pair<String, String> bestLocaleTagAndManifestUrl =
+ LocaleUtils.lookupBestLocaleTagAndManifestUrl(modelType, primaryLocale, settings);
+ if (bestLocaleTagAndManifestUrl == null) {
+ TcLog.w(
+ TAG,
+ String.format(
+ "No suitable manifest for %s, %s", modelType, primaryLocale.toLanguageTag()));
+ continue;
+ }
+ modelTypeToLocaleTagAndManifestUrlsBuilder.put(modelType, bestLocaleTagAndManifestUrl);
+ String bestLocaleTag = bestLocaleTagAndManifestUrl.first;
+ String manifestUrl = bestLocaleTagAndManifestUrl.second;
+ TcLog.v(
+ TAG,
+ String.format(
+ "model type: %s, device locale tag: %s, best locale tag: %s, manifest url: %s",
+ modelType, primaryLocale.toLanguageTag(), bestLocaleTag, manifestUrl));
+ if (!shouldDownloadManifest(modelType, bestLocaleTag, manifestUrl)) {
+ continue; // nothing to fetch: already enrolled with this manifest, or it failed too often
+ }
+ downloadResultFutures.add(downloadManifestAndRegister(modelType, bestLocaleTag, manifestUrl));
+ }
+ modelTypeToLocaleTagAndManifestUrls = modelTypeToLocaleTagAndManifestUrlsBuilder.build();
+
+ return Futures.whenAllComplete(downloadResultFutures)
+ .call(
+ () -> {
+ TcLog.v(TAG, "All Download Tasks Completed");
+ boolean allSucceeded = true;
+ for (ListenableFuture<Boolean> downloadResultFuture : downloadResultFutures) {
+ allSucceeded &= Futures.getDone(downloadResultFuture); // inputs are all done here
+ }
+ return allSucceeded;
+ },
+ executorService);
+ }
+
+ // Decides whether manifestUrl still has to be downloaded for the given model type and locale.
+ private boolean shouldDownloadManifest(
+ @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+ Manifest knownManifest = downloadedModelManager.getManifest(manifestUrl);
+ if (knownManifest == null) {
+ return true;
+ }
+ if (knownManifest.getStatus() != Manifest.STATUS_FAILED) {
+ ManifestEnrollment enrollment =
+ downloadedModelManager.getManifestEnrollment(modelType, localeTag);
+ return enrollment == null || !manifestUrl.equals(enrollment.getManifestUrl());
+ }
+ if (knownManifest.getFailureCounts() < settings.getManifestDownloadMaxAttempts()) {
+ return true;
+ }
+ TcLog.w(
+ TAG,
+ String.format(
+ "Manifest failed too many times, stop retrying: %s %d",
+ manifestUrl, knownManifest.getFailureCounts()));
+ return false;
+ }
+
+ /**
+ * Downloads a single manifest and the model configured inside it, then registers the enrollment.
+ *
+ * <p>The returned future should always resolve to a Boolean as we catch all exceptions: true
+ * once the enrollment is registered, false after the failure has been recorded and logged.
+ */
+ private ListenableFuture<Boolean> downloadManifestAndRegister(
+ @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+ return FluentFuture.from(downloadManifest(manifestUrl))
+ .transform(
+ unused -> {
+ downloadedModelManager.registerManifestEnrollment(modelType, localeTag, manifestUrl);
+ TextClassifierDownloadLogger.downloadSucceeded(
+ modelType, manifestUrl, getRunAttemptCount());
+ TcLog.v(TAG, "Manifest downloaded and registered: " + manifestUrl);
+ return true;
+ },
+ executorService)
+ .catching(
+ Throwable.class,
+ t -> {
+ downloadedModelManager.registerManifestDownloadFailure(manifestUrl);
+ int errorCode = ModelDownloadException.UNKNOWN_FAILURE_REASON;
+ if (t instanceof ModelDownloadException) {
+ errorCode = ((ModelDownloadException) t).getErrorCode();
+ }
+ TcLog.e(TAG, "Failed to download manifest: " + manifestUrl, t);
+ TextClassifierDownloadLogger.downloadFailedAndRetry(
+ modelType, manifestUrl, errorCode, getRunAttemptCount());
+ return false;
+ },
+ executorService);
+ }
+
+ // Downloads a manifest plus its first model, then records the manifest in the Manifest table.
+ private ListenableFuture<Void> downloadManifest(String manifestUrl) {
+ synchronized (lock) {
+ Manifest existing = downloadedModelManager.getManifest(manifestUrl);
+ if (existing != null && existing.getStatus() == Manifest.STATUS_SUCCEEDED) {
+ TcLog.v(TAG, "Manifest already downloaded: " + manifestUrl);
+ return Futures.immediateVoidFuture();
+ }
+ ListenableFuture<Void> inFlightDownload = pendingDownloads.get(manifestUrl);
+ if (inFlightDownload != null) {
+ return inFlightDownload;
+ }
+ ListenableFuture<Void> manifestDownloadFuture =
+ FluentFuture.from(downloader.downloadManifest(manifestUrl))
+ .transformAsync(
+ manifest -> {
+ ModelManifest.Model firstModel = manifest.getModels(0);
+ return Futures.transform(
+ downloadModel(firstModel), unused -> firstModel, executorService);
+ },
+ executorService)
+ .transform(
+ downloadedModelInfo -> {
+ downloadedModelManager.registerManifest(manifestUrl, downloadedModelInfo.getUrl());
+ return null;
+ },
+ executorService);
+ pendingDownloads.put(manifestUrl, manifestDownloadFuture);
+ return manifestDownloadFuture;
+ }
+ }
+ // Downloads a model file (or reuses a known/in-flight one) and records it in the Model table.
+ private ListenableFuture<Void> downloadModel(ModelManifest.Model modelInfo) {
+ String modelUrl = modelInfo.getUrl();
+ synchronized (lock) {
+ Model existingModel = downloadedModelManager.getModel(modelUrl);
+ if (existingModel != null) {
+ TcLog.d(TAG, "Model file already exists: " + existingModel.getModelPath());
+ return Futures.immediateVoidFuture();
+ }
+ ListenableFuture<Void> inFlightDownload = pendingDownloads.get(modelUrl);
+ if (inFlightDownload != null) {
+ return inFlightDownload;
+ }
+ ListenableFuture<Void> modelDownloadFuture =
+ FluentFuture.from(
+ downloader.downloadModel(downloadedModelManager.getModelDownloaderDir(), modelInfo))
+ .transform(
+ downloadedFile -> {
+ downloadedModelManager.registerModel(modelUrl, downloadedFile.getAbsolutePath());
+ TcLog.v(TAG, "Model File downloaded: " + modelUrl);
+ return null;
+ },
+ executorService);
+ pendingDownloads.put(modelUrl, modelDownloadFuture);
+ return modelDownloadFuture;
+ }
}
/**
@@ -187,39 +289,6 @@
*/
@Override
public final void onStopped() {
- TcLog.d(TAG, String.format("Stop download: %s, attempt:%d", manifestUrl, getRunAttemptCount()));
- TextClassifierDownloadLogger.downloadFailedAndRetry(
- modelType, manifestUrl, ModelDownloadException.WORKER_STOPPED, getRunAttemptCount());
- }
-
- static final Data createInputData(
- @ModelTypeDef String modelType,
- String localeTag,
- String manifestUrl,
- int maxDownloadAttempts) {
- return new Data.Builder()
- .putString(DATA_MODEL_TYPE_KEY, modelType)
- .putString(DATA_LOCALE_TAG_KEY, localeTag)
- .putString(DATA_MANIFEST_URL_KEY, manifestUrl)
- .putInt(DATA_MAX_DOWNLOAD_ATTEMPTS_KEY, maxDownloadAttempts)
- .build();
- }
-
- /**
- * Returns the absolute path to download a model.
- *
- * <p>Each file's name is uniquely formatted based on its unique remote manifest URL.
- *
- * @param modelType the type of the model image to download
- * @param url the unique remote url
- */
- static String formatFileNameByModelTypeAndUrl(
- @ModelType.ModelTypeDef String modelType, String url) {
- // TODO(licha): Consider preserving the folder hierarchy of the URL
- String fileMidName = url.replaceAll("[^A-Za-z0-9]", "_");
- if (fileMidName.startsWith("https___")) {
- fileMidName = fileMidName.substring("https___".length());
- }
- return String.format("%s.%s.model", modelType, fileMidName);
+ TcLog.d(TAG, String.format("Stop download. Attempt:%d", getRunAttemptCount()));
}
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
index 746931b..ad4a197 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -42,7 +42,6 @@
import com.android.os.AtomsProto.TextClassifierApiUsageReported;
import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
-import com.android.textclassifier.common.ModelFileManager;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.statsd.StatsdTestUtils;
import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
new file mode 100644
index 0000000..3c0319f
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
@@ -0,0 +1,394 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.android.textclassifier.ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT;
+import static com.google.common.truth.Truth.assertThat;
+
+import android.os.LocaleList;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.RegularFileFullMatchLister;
+import com.android.textclassifier.ModelFileManager.RegularFilePatternMatchLister;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
+import com.google.common.base.Optional;
+import com.google.common.collect.ImmutableList;
+import com.google.common.io.Files;
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public final class ModelFileManagerTest {
+ private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
+
+ @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+
+ @Mock private TextClassifierSettings.IDeviceConfig mockDeviceConfig;
+
+ @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+
+ private File rootTestDir;
+ private ModelFileManager modelFileManager;
+
+ @Before
+ public void setup() {
+ MockitoAnnotations.initMocks(this);
+
+ rootTestDir =
+ new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
+ rootTestDir.mkdirs();
+ modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ new TextClassifierSettings(mockDeviceConfig));
+ setDefaultLocalesRule.set(new LocaleList(DEFAULT_LOCALE));
+ }
+
+ @After
+ public void removeTestDir() {
+ recursiveDelete(rootTestDir);
+ }
+
+ @Test
+ public void annotatorModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
+ }
+
+ @Test
+ public void actionsModelPreloaded() {
+ verifyModelPreloadedAsAsset(
+ ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
+ }
+
+ @Test
+ public void langIdModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
+ }
+
+ private void verifyModelPreloadedAsAsset(
+ @ModelTypeDef String modelType, String expectedModelPath) {
+ List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
+ List<ModelFile> assetFiles =
+ modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());
+
+ assertThat(assetFiles).hasSize(1);
+ assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
+ }
+
+ @Test
+ public void findBestModel_versionCode() {
+ ModelFileManager.ModelFile olderModelFile =
+ createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFileManager.ModelFile newerModelFile =
+ createModelFile(LANGUAGE_INDEPENDENT, /* version */ 2);
+ ModelFileManager modelFileManager = createModelFileManager(olderModelFile, newerModelFile);
+
+ ModelFile bestModelFile = modelFileManager.findBestModelFile(MODEL_TYPE, null);
+ assertThat(bestModelFile).isEqualTo(newerModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageDependentModelIsPreferred() {
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFileManager.ModelFile languageDependentModelFile =
+ createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
+ assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_noMatchedLanguageModel() {
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFileManager.ModelFile languageDependentModelFile =
+ createModelFile("zh-hk", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_languageIsMoreImportantThanVersion() {
+ ModelFileManager.ModelFile matchButOlderModel =
+ createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version */ 1);
+ ModelFileManager.ModelFile mismatchButNewerModel = createModelFile("zh-hk", /* version */ 2);
+ ModelFileManager modelFileManager =
+ createModelFileManager(matchButOlderModel, mismatchButNewerModel);
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
+ assertThat(bestModelFile).isEqualTo(matchButOlderModel);
+ }
+
+ @Test
+ public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_onlyCheckLanguage() {
+ setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh"));
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFileManager.ModelFile languageDependentModelFile = createModelFile("zh", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
+ assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_match() {
+ setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh-hk"));
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFileManager.ModelFile languageDependentModelFile = createModelFile("zh", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh"));
+ assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_doNotMatch() {
+ setDefaultLocalesRule.set(LocaleList.forLanguageTags("en"));
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFileManager.ModelFile languageDependentModelFile = createModelFile("zh", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh"));
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_onlyPrimaryLocaleConsidered_noLocalePreferencesProvided() {
+ setDefaultLocalesRule.set(
+ new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFileManager.ModelFile nonPrimaryLocaleModelFile =
+ createModelFile("zh-hk", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, nonPrimaryLocaleModelFile);
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, /* localePreferences= */ null);
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void findBestModel_onlyPrimaryLocaleConsidered_localePreferencesProvided() {
+ setDefaultLocalesRule.set(
+ new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
+
+ ModelFileManager.ModelFile languageIndependentModelFile =
+ createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+ ModelFileManager.ModelFile nonPrimaryLocalePreferenceModelFile =
+ createModelFile("zh-hk", /* version */ 1);
+ ModelFileManager modelFileManager =
+ createModelFileManager(languageIndependentModelFile, nonPrimaryLocalePreferenceModelFile);
+
+ ModelFileManager.ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(
+ MODEL_TYPE,
+ new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
+ assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+ }
+
+ @Test
+ public void modelFileEquals() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ assertThat(modelA).isEqualTo(modelB);
+ }
+
+ @Test
+ public void modelFile_different() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ assertThat(modelA).isNotEqualTo(modelB);
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_languageDependentIsBetter() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/b", /* version= */ 2, LANGUAGE_INDEPENDENT, /* isAsset= */ false);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void modelFile_isPreferredTo_version() {
+ ModelFileManager.ModelFile modelA =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+
+ ModelFileManager.ModelFile modelB =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
+
+ assertThat(modelA.isPreferredTo(modelB)).isTrue();
+ }
+
+ @Test
+ public void modelFile_toModelInfo() {
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+
+ ModelInfo modelInfo = modelFile.toModelInfo();
+
+ assertThat(modelInfo.toModelName()).isEqualTo("ja_v2");
+ }
+
+ @Test
+ public void modelFile_toModelInfos() {
+ ModelFile englishModelFile =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "en", /* isAsset= */ false);
+ ModelFile japaneseModelFile =
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+
+ ImmutableList<Optional<ModelInfo>> modelInfos =
+ ModelFileManager.ModelFile.toModelInfos(
+ Optional.of(englishModelFile), Optional.of(japaneseModelFile));
+
+ assertThat(
+ modelInfos.stream()
+ .map(modelFile -> modelFile.transform(ModelInfo::toModelName).or(""))
+ .collect(Collectors.toList()))
+ .containsExactly("en_v1", "ja_v2")
+ .inOrder();
+ }
+
+ @Test
+ public void regularFileFullMatchLister() throws IOException {
+ File modelFile = new File(rootTestDir, "test.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
+ File wrongFile = new File(rootTestDir, "wrong.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
+
+ RegularFileFullMatchLister regularFileFullMatchLister =
+ new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
+ ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).hasSize(1);
+ assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
+ assertThat(listedModels.get(0).isAsset).isFalse();
+ }
+
+ @Test
+ public void regularFilePatternMatchLister() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
+ File modelFile2 = new File(rootTestDir, "annotator.fr.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
+ File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
+
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).hasSize(2);
+ assertThat(listedModels.get(0).isAsset).isFalse();
+ assertThat(listedModels.get(1).isAsset).isFalse();
+ assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath))
+ .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath());
+ }
+
+ @Test
+ public void regularFilePatternMatchLister_disabled() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
+
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).isEmpty();
+ }
+
+ private ModelFileManager createModelFileManager(ModelFile... modelFiles) {
+ return new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.copyOf(modelFiles)));
+ }
+
+ private ModelFileManager.ModelFile createModelFile(String supportedLocaleTags, int version) {
+ return new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ new File(rootTestDir, String.format("%s-%d", supportedLocaleTags, version))
+ .getAbsolutePath(),
+ version,
+ supportedLocaleTags,
+ /* isAsset= */ false);
+ }
+
+ private static void recursiveDelete(File f) {
+ // Deletes f and everything below it; listFiles() is null for non-directories and I/O errors.
+ File[] innerFiles = f.listFiles();
+ for (File innerFile : innerFiles == null ? new File[0] : innerFiles) {
+ recursiveDelete(innerFile);
+ }
+ f.delete();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
index 5c1d95e..48f71f3 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
@@ -17,8 +17,7 @@
package com.android.textclassifier;
import android.content.Context;
-import com.android.textclassifier.common.ModelFileManager;
-import com.android.textclassifier.common.ModelFileManager.RegularFileFullMatchLister;
+import com.android.textclassifier.ModelFileManager.RegularFileFullMatchLister;
import com.android.textclassifier.common.ModelType;
import com.google.common.collect.ImmutableList;
import java.io.File;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index 81aa832..5c36008 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -41,7 +41,6 @@
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
-import com.android.textclassifier.common.ModelFileManager;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.testing.FakeContextBuilder;
import com.google.common.collect.ImmutableList;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
deleted file mode 100644
index 40838ac..0000000
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
+++ /dev/null
@@ -1,507 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.common;
-
-import static com.android.textclassifier.common.ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT;
-import static com.google.common.truth.Truth.assertThat;
-
-import android.os.LocaleList;
-import androidx.test.core.app.ApplicationProvider;
-import androidx.test.ext.junit.runners.AndroidJUnit4;
-import androidx.test.filters.SmallTest;
-import com.android.textclassifier.TestDataUtils;
-import com.android.textclassifier.common.ModelFileManager.ModelFile;
-import com.android.textclassifier.common.ModelFileManager.RegularFileFullMatchLister;
-import com.android.textclassifier.common.ModelFileManager.RegularFilePatternMatchLister;
-import com.android.textclassifier.common.ModelType.ModelTypeDef;
-import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
-import com.android.textclassifier.testing.SetDefaultLocalesRule;
-import com.google.common.base.Optional;
-import com.google.common.collect.ImmutableList;
-import com.google.common.io.Files;
-import java.io.File;
-import java.io.IOException;
-import java.util.List;
-import java.util.Locale;
-import java.util.stream.Collectors;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public final class ModelFileManagerTest {
- private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
-
- @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
-
- @Mock private TextClassifierSettings.IDeviceConfig mockDeviceConfig;
-
- @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
-
- private File rootTestDir;
- private ModelFileManager modelFileManager;
-
- @Before
- public void setup() {
- MockitoAnnotations.initMocks(this);
-
- rootTestDir =
- new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
- rootTestDir.mkdirs();
- modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- new TextClassifierSettings(mockDeviceConfig));
- }
-
- @After
- public void removeTestDir() {
- recursiveDelete(rootTestDir);
- }
-
- @Test
- public void annotatorModelPreloaded() {
- verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
- }
-
- @Test
- public void actionsModelPreloaded() {
- verifyModelPreloadedAsAsset(
- ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
- }
-
- @Test
- public void langIdModelPreloaded() {
- verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
- }
-
- private void verifyModelPreloadedAsAsset(
- @ModelTypeDef String modelType, String expectedModelPath) {
- List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
- List<ModelFile> assetFiles =
- modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());
-
- assertThat(assetFiles).hasSize(1);
- assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
- }
-
- @Test
- public void findBestModel_versionCode() {
- ModelFileManager.ModelFile olderModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "a").getAbsolutePath(),
- /* version= */ 1,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager.ModelFile newerModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "b").getAbsolutePath(),
- /* version= */ 2,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(olderModelFile, newerModelFile)));
-
- ModelFile bestModelFile = modelFileManager.findBestModelFile(MODEL_TYPE, null);
- assertThat(bestModelFile).isEqualTo(newerModelFile);
- }
-
- @Test
- public void findBestModel_languageDependentModelIsPreferred() {
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "a").getAbsolutePath(),
- /* version= */ 1,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "b").getAbsolutePath(),
- /* version= */ 2,
- DEFAULT_LOCALE.toLanguageTag(),
- /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(
- modelType ->
- ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
-
- ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
- assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
- }
-
- @Test
- public void findBestModel_noMatchedLanguageModel() {
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "a").getAbsolutePath(),
- /* version= */ 1,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "b").getAbsolutePath(),
- /* version= */ 2,
- DEFAULT_LOCALE.toLanguageTag(),
- /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(
- modelType ->
- ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
- assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
- }
-
- @Test
- public void findBestModel_languageIsMoreImportantThanVersion() {
- ModelFileManager.ModelFile matchButOlderModel =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "a").getAbsolutePath(),
- /* version= */ 1,
- "fr",
- /* isAsset= */ false);
- ModelFileManager.ModelFile mismatchButNewerModel =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "b").getAbsolutePath(),
- /* version= */ 1,
- "ja",
- /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(
- modelType -> ImmutableList.of(matchButOlderModel, mismatchButNewerModel)));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("fr"));
- assertThat(bestModelFile).isEqualTo(matchButOlderModel);
- }
-
- @Test
- public void findBestModel_preferMatchedLocaleModel() {
- ModelFileManager.ModelFile matchLocaleModel =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "a").getAbsolutePath(),
- /* version= */ 1,
- "ja",
- /* isAsset= */ false);
- ModelFileManager.ModelFile languageIndependentModel =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File(rootTestDir, "b").getAbsolutePath(),
- /* version= */ 1,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(
- modelType -> ImmutableList.of(matchLocaleModel, languageIndependentModel)));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("ja"));
-
- assertThat(bestModelFile).isEqualTo(matchLocaleModel);
- }
-
- @Test
- public void deleteUnusedModelFiles_olderModelDeleted() throws Exception {
- File model1 = new File(rootTestDir, "model1.fb");
- model1.createNewFile();
- File model2 = new File(rootTestDir, "model2.fb");
- model2.createNewFile();
- ModelFileManager.ModelFile modelFile1 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
- ModelFileManager.ModelFile modelFile2 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "ja", /* isAsset= */ false);
- setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
-
- modelFileManager.deleteUnusedModelFiles();
-
- assertThat(model1.exists()).isFalse();
- assertThat(model2.exists()).isTrue();
- }
-
- @Test
- public void deleteUnusedModelFiles_languageIndependentOlderModelDeleted() throws Exception {
- File model1 = new File(rootTestDir, "model1.fb");
- model1.createNewFile();
- File model2 = new File(rootTestDir, "model2.fb");
- model2.createNewFile();
- ModelFileManager.ModelFile modelFile1 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- model1.getAbsolutePath(),
- /* version= */ 1,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- ModelFileManager.ModelFile modelFile2 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- model2.getAbsolutePath(),
- /* version= */ 2,
- LANGUAGE_INDEPENDENT,
- /* isAsset= */ false);
- setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
-
- modelFileManager.deleteUnusedModelFiles();
-
- assertThat(model1.exists()).isFalse();
- assertThat(model2.exists()).isTrue();
- }
-
- @Test
- public void deleteUnusedModelFiles_modelOnlySupportingLocalesNotInListDeleted() throws Exception {
- File model1 = new File(rootTestDir, "model1.fb");
- model1.createNewFile();
- File model2 = new File(rootTestDir, "model2.fb");
- model2.createNewFile();
- ModelFileManager.ModelFile modelFile1 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
- ModelFileManager.ModelFile modelFile2 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 1, "en", /* isAsset= */ false);
- setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
-
- modelFileManager.deleteUnusedModelFiles();
-
- assertThat(model1.exists()).isTrue();
- assertThat(model2.exists()).isFalse();
- }
-
- @Test
- public void deleteUnusedModelFiles_multiLocalesInLocaleList() throws Exception {
- File model1 = new File(rootTestDir, "model1.fb");
- model1.createNewFile();
- File model2 = new File(rootTestDir, "model2.fb");
- model2.createNewFile();
- ModelFileManager.ModelFile modelFile1 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
- ModelFileManager.ModelFile modelFile2 =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "en", /* isAsset= */ false);
- setDefaultLocalesRule.set(
- new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("en")));
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
- modelFileManager.deleteUnusedModelFiles();
-
- assertThat(model1.exists()).isTrue();
- assertThat(model2.exists()).isTrue();
- }
-
- @Test
- public void deleteUnusedModelFiles_readOnlyModelsUntouched() throws Exception {
- File readOnlyDir = new File(rootTestDir, "read_only/");
- readOnlyDir.mkdirs();
- File model1 = new File(readOnlyDir, "model1.fb");
- model1.createNewFile();
- readOnlyDir.setWritable(false);
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
- ModelFileManager modelFileManager =
- new ModelFileManager(
- ApplicationProvider.getApplicationContext(),
- ImmutableList.of(modelType -> ImmutableList.of(modelFile)));
- setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
-
- modelFileManager.deleteUnusedModelFiles();
-
- assertThat(model1.exists()).isTrue();
- }
-
- @Test
- public void modelFileEquals() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
-
- assertThat(modelA).isEqualTo(modelB);
- }
-
- @Test
- public void modelFile_different() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
-
- assertThat(modelA).isNotEqualTo(modelB);
- }
-
- @Test
- public void modelFile_isPreferredTo_languageDependentIsBetter() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/b", /* version= */ 2, LANGUAGE_INDEPENDENT, /* isAsset= */ false);
-
- assertThat(modelA.isPreferredTo(modelB)).isTrue();
- }
-
- @Test
- public void modelFile_isPreferredTo_version() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
-
- ModelFileManager.ModelFile modelB =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
-
- assertThat(modelA.isPreferredTo(modelB)).isTrue();
- }
-
- @Test
- public void modelFile_toModelInfo() {
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
-
- ModelInfo modelInfo = modelFile.toModelInfo();
-
- assertThat(modelInfo.toModelName()).isEqualTo("ja_v2");
- }
-
- @Test
- public void modelFile_toModelInfos() {
- ModelFile englishModelFile =
- new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "en", /* isAsset= */ false);
- ModelFile japaneseModelFile =
- new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
-
- ImmutableList<Optional<ModelInfo>> modelInfos =
- ModelFileManager.ModelFile.toModelInfos(
- Optional.of(englishModelFile), Optional.of(japaneseModelFile));
-
- assertThat(
- modelInfos.stream()
- .map(modelFile -> modelFile.transform(ModelInfo::toModelName).or(""))
- .collect(Collectors.toList()))
- .containsExactly("en_v1", "ja_v2")
- .inOrder();
- }
-
- @Test
- public void regularFileFullMatchLister() throws IOException {
- File modelFile = new File(rootTestDir, "test.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
- File wrongFile = new File(rootTestDir, "wrong.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
-
- RegularFileFullMatchLister regularFileFullMatchLister =
- new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
- ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);
-
- assertThat(listedModels).hasSize(1);
- assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
- assertThat(listedModels.get(0).isAsset).isFalse();
- }
-
- @Test
- public void regularFilePatternMatchLister() throws IOException {
- File modelFile1 = new File(rootTestDir, "annotator.en.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
- File modelFile2 = new File(rootTestDir, "annotator.fr.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
- File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
-
- RegularFilePatternMatchLister regularFilePatternMatchLister =
- new RegularFilePatternMatchLister(
- MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
- ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
-
- assertThat(listedModels).hasSize(2);
- assertThat(listedModels.get(0).isAsset).isFalse();
- assertThat(listedModels.get(1).isAsset).isFalse();
- assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath))
- .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath());
- }
-
- @Test
- public void regularFilePatternMatchLister_disabled() throws IOException {
- File modelFile1 = new File(rootTestDir, "annotator.en.model");
- Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
-
- RegularFilePatternMatchLister regularFilePatternMatchLister =
- new RegularFilePatternMatchLister(
- MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
- ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
-
- assertThat(listedModels).isEmpty();
- }
-
- private static void recursiveDelete(File f) {
- if (f.isDirectory()) {
- for (File innerFile : f.listFiles()) {
- recursiveDelete(innerFile);
- }
- }
- f.delete();
- }
-}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelDatabaseTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelDatabaseTest.java
new file mode 100644
index 0000000..835f50b
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelDatabaseTest.java
@@ -0,0 +1,398 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.testng.Assert.expectThrows;
+
+import android.content.Context;
+import androidx.room.Room;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestModelCrossRef;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ModelView;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import java.io.IOException;
+import java.util.List;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@RunWith(AndroidJUnit4.class)
+public class DownloadedModelDatabaseTest {
+ private static final String MODEL_URL = "https://model.url";
+ private static final String MODEL_URL_2 = "https://model2.url";
+ private static final String MODEL_PATH = "/data/test.model";
+ private static final String MODEL_PATH_2 = "/data/test.model2";
+ private static final String MANIFEST_URL = "https://manifest.url";
+ private static final String MANIFEST_URL_2 = "https://manifest2.url";
+ private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+ private static final String MODEL_TYPE_2 = ModelType.ACTIONS_SUGGESTIONS;
+ private static final String LOCALE_TAG = "zh";
+
+ private DownloadedModelDatabase db;
+
+ @Before
+ public void createDb() {
+ Context context = ApplicationProvider.getApplicationContext();
+ db = Room.inMemoryDatabaseBuilder(context, DownloadedModelDatabase.class).build(); // in-memory Room database, rebuilt before every test
+ }
+
+ @After
+ public void closeDb() throws IOException {
+ db.close(); // closes the database built in createDb()
+ }
+
+ @Test
+ public void insertModelAndRead() throws Exception {
+ Model model = Model.create(MODEL_URL, MODEL_PATH);
+ db.dao().insert(model);
+ List<Model> models = db.dao().queryAllModels();
+ assertThat(models).containsExactly(model); // round trip: the inserted row is the only one returned
+ }
+
+ @Test
+ public void insertModelAndDelete() throws Exception {
+ Model model = Model.create(MODEL_URL, MODEL_PATH);
+ db.dao().insert(model);
+ db.dao().deleteModels(ImmutableList.of(model));
+ List<Model> models = db.dao().queryAllModels();
+ assertThat(models).isEmpty(); // the only inserted row was deleted again
+ }
+
+ @Test
+ public void insertManifestAndRead() throws Exception {
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ List<Manifest> manifests = db.dao().queryAllManifests();
+ assertThat(manifests).containsExactly(manifest); // round trip: the inserted row is the only one returned
+ }
+
+ @Test
+ public void insertManifestAndDelete() throws Exception {
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ db.dao().deleteManifests(ImmutableList.of(manifest));
+ List<Manifest> manifests = db.dao().queryAllManifests();
+ assertThat(manifests).isEmpty(); // the only inserted row was deleted again
+ }
+
+ @Test
+ public void insertManifestModelCrossRefAndRead() throws Exception {
+ Model model = Model.create(MODEL_URL, MODEL_PATH);
+ db.dao().insert(model);
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ ManifestModelCrossRef manifestModelCrossRef =
+ ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+ db.dao().insert(manifestModelCrossRef); // both referenced rows (model and manifest) were inserted above
+ List<ManifestModelCrossRef> manifestModelCrossRefs = db.dao().queryAllManifestModelCrossRefs();
+ assertThat(manifestModelCrossRefs).containsExactly(manifestModelCrossRef);
+ }
+
+ @Test
+ public void insertManifestModelCrossRefAndDelete() throws Exception {
+ Model model = Model.create(MODEL_URL, MODEL_PATH);
+ db.dao().insert(model);
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ ManifestModelCrossRef manifestModelCrossRef =
+ ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+ db.dao().insert(manifestModelCrossRef);
+ db.dao().deleteManifestModelCrossRefs(ImmutableList.of(manifestModelCrossRef)); // delete the cross-ref row directly
+ List<ManifestModelCrossRef> manifestModelCrossRefs = db.dao().queryAllManifestModelCrossRefs();
+ assertThat(manifestModelCrossRefs).isEmpty();
+ }
+
+ @Test
+ public void insertManifestModelCrossRefAndDeleteManifest() throws Exception {
+ Model model = Model.create(MODEL_URL, MODEL_PATH);
+ db.dao().insert(model);
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ ManifestModelCrossRef manifestModelCrossRef =
+ ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+ db.dao().insert(manifestModelCrossRef);
+ db.dao().deleteManifests(ImmutableList.of(manifest)); // expected to cascade to the cross-ref row (ON DELETE CASCADE)
+ List<ManifestModelCrossRef> manifestModelCrossRefs = db.dao().queryAllManifestModelCrossRefs();
+ assertThat(manifestModelCrossRefs).isEmpty(); // cross-ref is gone although only the manifest was deleted
+ }
+
+ @Test
+ public void insertManifestModelCrossRefAndDeleteModel() throws Exception {
+ Model model = Model.create(MODEL_URL, MODEL_PATH);
+ db.dao().insert(model);
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ ManifestModelCrossRef manifestModelCrossRef =
+ ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+ db.dao().insert(manifestModelCrossRef);
+ db.dao().deleteModels(ImmutableList.of(model)); // expected to cascade to the cross-ref row (ON DELETE CASCADE)
+ List<ManifestModelCrossRef> manifestModelCrossRefs = db.dao().queryAllManifestModelCrossRefs();
+ assertThat(manifestModelCrossRefs).isEmpty(); // cross-ref is gone although only the model was deleted
+ }
+
+ @Test
+ public void insertManifestModelCrossRefWithoutManifest() throws Exception {
+ Model model = Model.create(MODEL_URL, MODEL_PATH);
+ db.dao().insert(model);
+ ManifestModelCrossRef manifestModelCrossRef =
+ ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL); // no Manifest row with MANIFEST_URL was inserted
+ expectThrows(Throwable.class, () -> db.dao().insert(manifestModelCrossRef)); // presumably a foreign-key violation - confirm in the schema
+ }
+
+ @Test
+ public void insertManifestModelCrossRefWithoutModel() throws Exception {
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ ManifestModelCrossRef manifestModelCrossRef =
+ ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL); // no Model row with MODEL_URL was inserted
+ expectThrows(Throwable.class, () -> db.dao().insert(manifestModelCrossRef)); // presumably a foreign-key violation - confirm in the schema
+ }
+
+ @Test
+ public void insertManifestEnrollmentAndRead() throws Exception {
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest); // the enrollment below refers to this manifest by URL
+ ManifestEnrollment manifestEnrollment =
+ ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ db.dao().insert(manifestEnrollment);
+ List<ManifestEnrollment> manifestEnrollments = db.dao().queryAllManifestEnrollments();
+ assertThat(manifestEnrollments).containsExactly(manifestEnrollment);
+ }
+
+ @Test
+ public void insertManifestEnrollmentAndDelete() throws Exception {
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ ManifestEnrollment manifestEnrollment =
+ ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ db.dao().insert(manifestEnrollment);
+ db.dao().deleteManifestEnrollments(ImmutableList.of(manifestEnrollment)); // delete the enrollment row directly
+ List<ManifestEnrollment> manifestEnrollments = db.dao().queryAllManifestEnrollments();
+ assertThat(manifestEnrollments).isEmpty();
+ }
+
+ @Test
+ public void insertManifestEnrollmentAndDeleteManifest() throws Exception {
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ ManifestEnrollment manifestEnrollment =
+ ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ db.dao().insert(manifestEnrollment);
+ db.dao().deleteManifests(ImmutableList.of(manifest)); // only the manifest is deleted explicitly
+ List<ManifestEnrollment> manifestEnrollments = db.dao().queryAllManifestEnrollments();
+ assertThat(manifestEnrollments).isEmpty(); // its enrollment row is expected to be removed with it
+ }
+
+ @Test
+ public void insertManifestEnrollmentWithoutManifest() throws Exception {
+ ManifestEnrollment manifestEnrollment =
+ ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL); // no Manifest row with MANIFEST_URL was inserted
+ expectThrows(Throwable.class, () -> db.dao().insert(manifestEnrollment)); // presumably a foreign-key violation - confirm in the schema
+ }
+
+ @Test
+ public void insertModelViewAndRead() throws Exception {
+ Model model = Model.create(MODEL_URL, MODEL_PATH);
+ db.dao().insert(model);
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ ManifestModelCrossRef manifestModelCrossRef =
+ ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+ db.dao().insert(manifestModelCrossRef); // links the manifest to the model
+ ManifestEnrollment manifestEnrollment =
+ ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ db.dao().insert(manifestEnrollment); // enrolls the same manifest for MODEL_TYPE / LOCALE_TAG
+
+ List<ModelView> modelViews = db.dao().queryAllModelViews(); // ModelView rows are never inserted directly in this test
+ ModelView modelView = Iterables.getOnlyElement(modelViews); // exactly one view row is expected
+ assertThat(modelView.getManifestEnrollment()).isEqualTo(manifestEnrollment);
+ assertThat(modelView.getModel()).isEqualTo(model);
+ }
+
+ @Test
+ public void queryModelWithModelUrl() throws Exception {
+ Model model = Model.create(MODEL_URL, MODEL_PATH);
+ db.dao().insert(model);
+ Model model2 = Model.create(MODEL_URL_2, MODEL_PATH_2);
+ db.dao().insert(model2);
+
+ assertThat(db.dao().queryModelWithModelUrl(MODEL_URL)).containsExactly(model); // each URL returns only its own model
+ assertThat(db.dao().queryModelWithModelUrl(MODEL_URL_2)).containsExactly(model2);
+ }
+
+ @Test
+ public void queryManifestWithManifestUrl() throws Exception {
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ Manifest manifest2 =
+ Manifest.create(MANIFEST_URL_2, Manifest.STATUS_FAILED, /* failureCounts= */ 1); // differs in URL, status and failure count
+ db.dao().insert(manifest2);
+
+ assertThat(db.dao().queryManifestWithManifestUrl(MANIFEST_URL)).containsExactly(manifest); // each URL returns only its own manifest
+ assertThat(db.dao().queryManifestWithManifestUrl(MANIFEST_URL_2)).containsExactly(manifest2);
+ }
+
+ @Test
+ public void queryManifestEnrollmentWithModelTypeAndLocaleTag() throws Exception {
+ Manifest manifest =
+ Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest);
+ Manifest manifest2 =
+ Manifest.create(MANIFEST_URL_2, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+ db.dao().insert(manifest2);
+ ManifestEnrollment manifestEnrollment =
+ ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ db.dao().insert(manifestEnrollment);
+ ManifestEnrollment manifestEnrollment2 =
+ ManifestEnrollment.create(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL_2); // same locale tag, different model type
+ db.dao().insert(manifestEnrollment2);
+
+ assertThat(db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(MODEL_TYPE, LOCALE_TAG))
+ .containsExactly(manifestEnrollment); // the model type alone separates the two rows
+ assertThat(db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(MODEL_TYPE_2, LOCALE_TAG))
+ .containsExactly(manifestEnrollment2);
+ }
+
+ @Test
+ public void insertManifestAndModelCrossRef() throws Exception {
+   Model storedModel = Model.create(MODEL_URL, MODEL_PATH);
+   db.dao().insert(storedModel);
+   // Built only as the expected value: this test never inserts the manifest row directly.
+   Manifest expectedManifest =
+       Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+   db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL);
+
+   // After the call, the model is still present and a succeeded manifest row exists.
+   assertThat(db.dao().queryAllModels()).containsExactly(storedModel);
+   assertThat(db.dao().queryAllManifests()).containsExactly(expectedManifest);
+ }
+
+ @Test
+ public void increaseManifestFailureCounts() throws Exception {
+   // The key is a manifest URL, so use MANIFEST_URL (not MODEL_URL) for both the write and the
+   // lookup, matching the other manifest tests in this class.
+   // Nothing is inserted beforehand: the first call must produce a FAILED record with count 1.
+   db.dao().increaseManifestFailureCounts(MANIFEST_URL);
+   Manifest manifest =
+       Iterables.getOnlyElement(db.dao().queryManifestWithManifestUrl(MANIFEST_URL));
+   assertThat(manifest.getStatus()).isEqualTo(Manifest.STATUS_FAILED);
+   assertThat(manifest.getFailureCounts()).isEqualTo(1);
+
+   // A second failure for the same URL must leave a single record whose count is now 2.
+   db.dao().increaseManifestFailureCounts(MANIFEST_URL);
+   manifest = Iterables.getOnlyElement(db.dao().queryManifestWithManifestUrl(MANIFEST_URL));
+   assertThat(manifest.getStatus()).isEqualTo(Manifest.STATUS_FAILED);
+   assertThat(manifest.getFailureCounts()).isEqualTo(2);
+ }
+
+ @Test
+ public void deleteUnusedManifestsAndModels_unusedManifestAndUnusedModel() throws Exception {
+   // Two independent manifest/model pairs; insertion order is kept as-is on purpose.
+   Model enrolledModel = Model.create(MODEL_URL, MODEL_PATH);
+   db.dao().insert(enrolledModel);
+   Model orphanModel = Model.create(MODEL_URL_2, MODEL_PATH_2);
+   db.dao().insert(orphanModel);
+   Manifest enrolledManifest =
+       Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+   db.dao().insert(enrolledManifest);
+   db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL);
+   Manifest orphanManifest =
+       Manifest.create(MANIFEST_URL_2, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+   db.dao().insert(orphanManifest);
+   db.dao().insertManifestAndModelCrossRef(MANIFEST_URL_2, MODEL_URL_2);
+   // Only the first manifest gets an enrollment.
+   ManifestEnrollment enrollment =
+       ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+   db.dao().insert(enrollment);
+
+   db.dao().deleteUnusedManifestsAndModels();
+
+   // The pair without an enrollment must be gone; the enrolled pair must survive.
+   assertThat(db.dao().queryAllManifests()).containsExactly(enrolledManifest);
+   assertThat(db.dao().queryAllModels()).containsExactly(enrolledModel);
+ }
+
+ @Test
+ public void deleteUnusedManifestsAndModels_unusedManifestAndSharedModel() throws Exception {
+   // A single model referenced by two manifests; insertion order is kept as-is on purpose.
+   Model sharedModel = Model.create(MODEL_URL, MODEL_PATH);
+   db.dao().insert(sharedModel);
+   Manifest enrolledManifest =
+       Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+   db.dao().insert(enrolledManifest);
+   db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL);
+   Manifest orphanManifest =
+       Manifest.create(MANIFEST_URL_2, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+   db.dao().insert(orphanManifest);
+   db.dao().insertManifestAndModelCrossRef(MANIFEST_URL_2, MODEL_URL);
+   // Only the first manifest gets an enrollment.
+   ManifestEnrollment enrollment =
+       ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+   db.dao().insert(enrollment);
+
+   db.dao().deleteUnusedManifestsAndModels();
+
+   // The manifest without an enrollment is removed, but the shared model must stay.
+   assertThat(db.dao().queryAllManifests()).containsExactly(enrolledManifest);
+   assertThat(db.dao().queryAllModels()).containsExactly(sharedModel);
+ }
+
+ @Test
+ public void deleteUnusedManifestsAndModels_failedManifest() throws Exception {
+   // A lone failure record with no enrollment and no model attached.
+   Manifest failureRecord =
+       Manifest.create(MANIFEST_URL, Manifest.STATUS_FAILED, /* failureCounts= */ 1);
+   db.dao().insert(failureRecord);
+
+   db.dao().deleteUnusedManifestsAndModels();
+
+   // The failed manifest must not be removed by this cleanup.
+   assertThat(db.dao().queryAllManifests()).containsExactly(failureRecord);
+ }
+
+ @Test
+ public void deleteUnusedManifestsAndModels_unusedModels() throws Exception {
+   // Two models, but only the first one is linked to a manifest.
+   Model linkedModel = Model.create(MODEL_URL, MODEL_PATH);
+   db.dao().insert(linkedModel);
+   Model unlinkedModel = Model.create(MODEL_URL_2, MODEL_PATH_2);
+   db.dao().insert(unlinkedModel);
+   Manifest enrolledManifest =
+       Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+   db.dao().insert(enrolledManifest);
+   db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL);
+   ManifestEnrollment enrollment =
+       ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+   db.dao().insert(enrollment);
+
+   db.dao().deleteUnusedManifestsAndModels();
+
+   // The model that no manifest references must be gone.
+   assertThat(db.dao().queryAllModels()).containsExactly(linkedModel);
+ }
+
+ @Test
+ public void deleteUnusedManifestFailureRecords() throws Exception {
+   // Two failure records under different manifest URLs.
+   Manifest listedFailure =
+       Manifest.create(MANIFEST_URL, Manifest.STATUS_FAILED, /* failureCounts= */ 1);
+   db.dao().insert(listedFailure);
+   Manifest unlistedFailure =
+       Manifest.create(MANIFEST_URL_2, Manifest.STATUS_FAILED, /* failureCounts= */ 1);
+   db.dao().insert(unlistedFailure);
+
+   db.dao().deleteUnusedManifestFailureRecords(ImmutableList.of(MANIFEST_URL));
+
+   // Only the record whose URL was passed in the list must remain.
+   assertThat(db.dao().queryAllManifests()).containsExactly(listedFailure);
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelManagerImplTest.java
new file mode 100644
index 0000000..2715a05
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelManagerImplTest.java
@@ -0,0 +1,280 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Context;
+import android.util.Pair;
+import androidx.room.Room;
+import androidx.test.core.app.ApplicationProvider;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestModelCrossRef;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.testing.TestingDeviceConfig;
+import com.google.common.collect.ImmutableMap;
+import java.io.File;
+import java.util.Map;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link DownloadedModelManagerImpl}, run against an in-memory Room database and a
+ * scratch directory under the application's files dir.
+ */
+@RunWith(JUnit4.class)
+public final class DownloadedModelManagerImplTest {
+
+  private File modelDownloaderDir;
+  private DownloadedModelDatabase db;
+  private DownloadedModelManagerImpl manager;
+  private TestingDeviceConfig deviceConfig;
+  private TextClassifierSettings settings;
+
+  @Before
+  public void setUp() {
+    Context appContext = ApplicationProvider.getApplicationContext();
+    modelDownloaderDir = new File(appContext.getFilesDir(), "test_dir");
+    modelDownloaderDir.mkdirs();
+    deviceConfig = new TestingDeviceConfig();
+    settings = new TextClassifierSettings(deviceConfig);
+    db = Room.inMemoryDatabaseBuilder(appContext, DownloadedModelDatabase.class).build();
+    manager = DownloadedModelManagerImpl.getInstanceForTesting(db, modelDownloaderDir, settings);
+  }
+
+  @After
+  public void cleanUp() {
+    DownloaderTestUtils.deleteRecursively(modelDownloaderDir);
+    db.close();
+  }
+
+  @Test
+  public void getModelDownloaderDir() throws Exception {
+    // Remove the directory created in setUp; the assertions below expect it to exist again.
+    modelDownloaderDir.delete();
+    assertThat(manager.getModelDownloaderDir().exists()).isTrue();
+    assertThat(manager.getModelDownloaderDir()).isEqualTo(modelDownloaderDir);
+  }
+
+  @Test
+  public void listModels_cacheNotInitialized() throws Exception {
+    // Rows are written straight to the database before the manager is queried for the first time.
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
+    registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
+
+    assertThat(manager.listModels(ModelType.ANNOTATOR))
+        .containsExactly(new File("modelPathEn"), new File("modelPathZh"));
+    assertThat(manager.listModels(ModelType.LANG_ID)).isEmpty();
+  }
+
+  @Test
+  public void listModels_doNotListBlockedModels() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
+    registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
+    // Block the "en" model URL plus one URL that is not registered at all.
+    String blocklist =
+        String.format(
+            "%s%s%s",
+            "modelUrlEn", TextClassifierSettings.MODEL_URL_BLOCKLIST_SEPARATOR, "modelUrlXX");
+    deviceConfig.setConfig(TextClassifierSettings.MODEL_URL_BLOCKLIST, blocklist);
+
+    assertThat(manager.listModels(ModelType.ANNOTATOR)).containsExactly(new File("modelPathZh"));
+  }
+
+  @Test
+  public void listModels_cacheNotUpdatedUnlessOnDownloadCompleted() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
+    assertThat(manager.listModels(ModelType.ANNOTATOR)).containsExactly(new File("modelPathEn"));
+
+    // A second row added behind the manager's back must not show up yet.
+    registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
+    assertThat(manager.listModels(ModelType.ANNOTATOR)).containsExactly(new File("modelPathEn"));
+
+    // Only after onDownloadCompleted does the listing change.
+    manager.onDownloadCompleted(annotatorManifestFor("zh", "manifestUrlZh"));
+    assertThat(manager.listModels(ModelType.ANNOTATOR)).containsExactly(new File("modelPathZh"));
+  }
+
+  @Test
+  public void getModel() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
+
+    assertThat(manager.getModel("modelUrl").getModelPath()).isEqualTo("modelPath");
+    assertThat(manager.getModel("modelUrl2")).isNull();
+  }
+
+  @Test
+  public void getManifest() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
+
+    assertThat(manager.getManifest("manifestUrl")).isNotNull();
+    assertThat(manager.getManifest("manifestUrl2")).isNull();
+  }
+
+  @Test
+  public void getManifestEnrollment() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
+
+    ManifestEnrollment englishEnrollment = manager.getManifestEnrollment(ModelType.ANNOTATOR, "en");
+    assertThat(englishEnrollment.getManifestUrl()).isEqualTo("manifestUrl");
+    assertThat(manager.getManifestEnrollment(ModelType.ANNOTATOR, "zh")).isNull();
+  }
+
+  @Test
+  public void registerModel() throws Exception {
+    manager.registerModel("modelUrl", "modelPath");
+
+    assertThat(manager.getModel("modelUrl").getModelPath()).isEqualTo("modelPath");
+  }
+
+  @Test
+  public void registerManifest() throws Exception {
+    manager.registerModel("modelUrl", "modelPath");
+    manager.registerManifest("manifestUrl", "modelUrl");
+
+    assertThat(manager.getManifest("manifestUrl")).isNotNull();
+  }
+
+  @Test
+  public void registerManifestDownloadFailure() throws Exception {
+    manager.registerManifestDownloadFailure("manifestUrl");
+
+    Manifest failureRecord = manager.getManifest("manifestUrl");
+    assertThat(failureRecord.getStatus()).isEqualTo(Manifest.STATUS_FAILED);
+    assertThat(failureRecord.getFailureCounts()).isEqualTo(1);
+  }
+
+  @Test
+  public void registerManifestEnrollment() throws Exception {
+    manager.registerModel("modelUrl", "modelPath");
+    manager.registerManifest("manifestUrl", "modelUrl");
+    manager.registerManifestEnrollment(ModelType.ANNOTATOR, "en", "manifestUrl");
+
+    ManifestEnrollment enrollment = manager.getManifestEnrollment(ModelType.ANNOTATOR, "en");
+    assertThat(enrollment.getModelType()).isEqualTo(ModelType.ANNOTATOR);
+    assertThat(enrollment.getLocaleTag()).isEqualTo("en");
+    assertThat(enrollment.getManifestUrl()).isEqualTo("manifestUrl");
+  }
+
+  @Test
+  public void onDownloadCompleted_newModelDownloaded() throws Exception {
+    File modelFile1 = downloadAndVerifyFirstModel();
+
+    // A newer manifest for the same model type and locale arrives with its own model file.
+    File modelFile2 = createModelFile("modelFile2");
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl2", "modelUrl2", modelFile2.getAbsolutePath());
+    manager.onDownloadCompleted(annotatorManifestFor("en", "manifestUrl2"));
+
+    // The old file is deleted from disk and only the new one is listed.
+    assertThat(modelFile1.exists()).isFalse();
+    assertThat(modelFile2.exists()).isTrue();
+    assertThat(manager.listModels(ModelType.ANNOTATOR)).containsExactly(modelFile2);
+  }
+
+  @Test
+  public void onDownloadCompleted_newModelDownloadFailed() throws Exception {
+    File modelFile1 = downloadAndVerifyFirstModel();
+
+    // The newer manifest only has a failure record, no model.
+    manager.registerManifestDownloadFailure("manifestUrl2");
+    manager.onDownloadCompleted(annotatorManifestFor("en", "manifestUrl2"));
+
+    // The previously downloaded model stays on disk and stays listed.
+    assertThat(modelFile1.exists()).isTrue();
+    assertThat(manager.listModels(ModelType.ANNOTATOR)).containsExactly(modelFile1);
+  }
+
+  // NOTE(review): "flat" in the method name looks like a typo for "flag" - confirm before renaming.
+  @Test
+  public void onDownloadCompleted_flatUnset() throws Exception {
+    File modelFile1 = downloadAndVerifyFirstModel();
+
+    // Completing a download with no manifests to keep at all.
+    Map<String, Pair<String, String>> noManifests = ImmutableMap.of();
+    manager.onDownloadCompleted(noManifests);
+
+    assertThat(modelFile1.exists()).isFalse();
+    assertThat(manager.listModels(ModelType.ANNOTATOR)).isEmpty();
+  }
+
+  @Test
+  public void onDownloadCompleted_cleanUpFailureRecords() throws Exception {
+    manager.registerManifestDownloadFailure("manifestUrl1");
+    manager.registerManifestDownloadFailure("manifestUrl2");
+    manager.onDownloadCompleted(annotatorManifestFor("en", "manifestUrl1"));
+
+    // The failure record for the manifest passed to onDownloadCompleted is kept; the other is gone.
+    assertThat(manager.getManifest("manifestUrl1").getStatus()).isEqualTo(Manifest.STATUS_FAILED);
+    assertThat(manager.getManifest("manifestUrl2")).isNull();
+  }
+
+  /**
+   * Creates "modelFile1", registers it for ("en", "manifestUrl1"), completes that download and
+   * checks that the file exists and is the only annotator model listed.
+   */
+  private File downloadAndVerifyFirstModel() throws Exception {
+    File modelFile1 = createModelFile("modelFile1");
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath());
+    manager.onDownloadCompleted(annotatorManifestFor("en", "manifestUrl1"));
+
+    assertThat(modelFile1.exists()).isTrue();
+    assertThat(manager.listModels(ModelType.ANNOTATOR)).containsExactly(modelFile1);
+    return modelFile1;
+  }
+
+  /** Creates an empty file with the given name inside the test's model downloader directory. */
+  private File createModelFile(String fileName) throws Exception {
+    File modelFile = new File(modelDownloaderDir, fileName);
+    modelFile.createNewFile();
+    return modelFile;
+  }
+
+  /** Builds the single-entry map (annotator type to locale tag and manifest URL). */
+  private static Map<String, Pair<String, String>> annotatorManifestFor(
+      String localeTag, String manifestUrl) {
+    return ImmutableMap.of(ModelType.ANNOTATOR, Pair.create(localeTag, manifestUrl));
+  }
+
+  /**
+   * Writes a model, a succeeded manifest, their cross reference and an enrollment directly into
+   * the database, without going through the manager. The insert order is kept as-is on purpose.
+   */
+  private void registerManifestToDB(
+      @ModelTypeDef String modelType,
+      String localeTag,
+      String manifestUrl,
+      String modelUrl,
+      String modelPath) {
+    db.dao().insert(Model.create(modelUrl, modelPath));
+    db.dao()
+        .insert(Manifest.create(manifestUrl, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+    db.dao().insert(ManifestModelCrossRef.create(manifestUrl, modelUrl));
+    db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl));
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloaderTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloaderTestUtils.java
index 1337130..37394e6 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloaderTestUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloaderTestUtils.java
@@ -16,80 +16,21 @@
package com.android.textclassifier.downloader;
-import android.content.Context;
-import androidx.work.ListenableWorker;
import androidx.work.WorkInfo;
import androidx.work.WorkManager;
import androidx.work.WorkQuery;
-import androidx.work.WorkerParameters;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterables;
-import com.google.common.util.concurrent.Futures;
-import com.google.common.util.concurrent.ListenableFuture;
import java.io.File;
import java.util.List;
/** Utils for downloader logic testing. */
final class DownloaderTestUtils {
- /** One unique queue holds at most one request at one time. Returns null if no WorkInfo found. */
- public static WorkInfo queryTheOnlyWorkInfo(WorkManager workManager, String queueName)
+ public static List<WorkInfo> queryWorkInfos(WorkManager workManager, String queueName)
throws Exception {
WorkQuery workQuery =
WorkQuery.Builder.fromUniqueWorkNames(ImmutableList.of(queueName)).build();
- List<WorkInfo> workInfos = workManager.getWorkInfos(workQuery).get();
- if (workInfos.isEmpty()) {
- return null;
- } else {
- return Iterables.getOnlyElement(workInfos);
- }
- }
-
- /**
- * Completes immediately with the pre-set result. If it's not retry, the result will also include
- * the input Data as its output Data.
- */
- public static final class TestWorker extends ListenableWorker {
- private static Result expectedResult;
-
- public TestWorker(Context context, WorkerParameters workerParams) {
- super(context, workerParams);
- }
-
- @Override
- public ListenableFuture<ListenableWorker.Result> startWork() {
- if (expectedResult == null) {
- return Futures.immediateFailedFuture(new Exception("no expected result"));
- }
- ListenableWorker.Result result;
- switch (expectedResult) {
- case SUCCESS:
- result = ListenableWorker.Result.success(getInputData());
- break;
- case FAILURE:
- result = ListenableWorker.Result.failure(getInputData());
- break;
- case RETRY:
- result = ListenableWorker.Result.retry();
- break;
- default:
- throw new IllegalStateException("illegal result");
- }
- // Reset expected result
- expectedResult = null;
- return Futures.immediateFuture(result);
- }
-
- /** Sets the expected worker result in a static variable. Will be cleaned up after reading. */
- public static void setExpectedResult(Result expectedResult) {
- TestWorker.expectedResult = expectedResult;
- }
-
- public enum Result {
- SUCCESS,
- FAILURE,
- RETRY;
- }
+ return workManager.getWorkInfos(workQuery).get();
}
// MoreFiles#deleteRecursively is not available for Android guava.
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/LocaleUtilsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/LocaleUtilsTest.java
new file mode 100644
index 0000000..a553c51
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/LocaleUtilsTest.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.util.Pair;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.testing.TestingDeviceConfig;
+import com.google.common.collect.ImmutableList;
+import java.util.Locale;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for the locale tag and manifest URL lookups in {@link LocaleUtils}. */
+@RunWith(JUnit4.class)
+public final class LocaleUtilsTest {
+  private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+
+  private TestingDeviceConfig deviceConfig;
+  private TextClassifierSettings settings;
+
+  @Before
+  public void setUp() {
+    deviceConfig = new TestingDeviceConfig();
+    settings = new TextClassifierSettings(deviceConfig);
+  }
+
+  @Test
+  public void lookupBestLocaleTag_simpleMatch() {
+    Locale english = Locale.forLanguageTag("en");
+
+    assertThat(LocaleUtils.lookupBestLocaleTag(english, ImmutableList.of("en", "zh")))
+        .isEqualTo("en");
+  }
+
+  @Test
+  public void lookupBestLocaleTag_noMatch() {
+    Locale english = Locale.forLanguageTag("en");
+    Locale americanEnglish = Locale.forLanguageTag("en-US");
+
+    // A different language, a more specific tag, and a different region all fail to match.
+    assertThat(LocaleUtils.lookupBestLocaleTag(english, ImmutableList.of("zh"))).isNull();
+    assertThat(LocaleUtils.lookupBestLocaleTag(english, ImmutableList.of("en-uk"))).isNull();
+    assertThat(LocaleUtils.lookupBestLocaleTag(americanEnglish, ImmutableList.of("en-uk")))
+        .isNull();
+  }
+
+  @Test
+  public void lookupBestLocaleTag_partialMatch() {
+    Locale americanEnglish = Locale.forLanguageTag("en-US");
+
+    // Falls back to the bare language when no region-specific tag is offered.
+    assertThat(LocaleUtils.lookupBestLocaleTag(americanEnglish, ImmutableList.of("en", "zh")))
+        .isEqualTo("en");
+    // Prefers the region-specific tag when it is offered.
+    assertThat(LocaleUtils.lookupBestLocaleTag(americanEnglish, ImmutableList.of("en", "en-us")))
+        .isEqualTo("en-us");
+    // A tag for another region is skipped in favour of the bare language.
+    assertThat(LocaleUtils.lookupBestLocaleTag(americanEnglish, ImmutableList.of("en", "en-uk")))
+        .isEqualTo("en");
+  }
+
+  @Test
+  public void lookupBestLocaleTag_universalMatch() {
+    Locale english = Locale.forLanguageTag("en");
+
+    assertThat(
+            LocaleUtils.lookupBestLocaleTag(
+                english, ImmutableList.of("zh", LocaleUtils.UNIVERSAL_LOCALE_TAG)))
+        .isEqualTo(LocaleUtils.UNIVERSAL_LOCALE_TAG);
+  }
+
+  @Test
+  public void lookupBestLocaleTagAndManifestUrl_found() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, "en", "url_1");
+
+    Pair<String, String> localeTagAndUrl =
+        LocaleUtils.lookupBestLocaleTagAndManifestUrl(
+            MODEL_TYPE, Locale.forLanguageTag("en"), settings);
+
+    assertThat(localeTagAndUrl.first).isEqualTo("en");
+    assertThat(localeTagAndUrl.second).isEqualTo("url_1");
+  }
+
+  @Test
+  public void lookupBestLocaleTagAndManifestUrl_notFound() throws Exception {
+    // No manifest URL flag is configured in this test.
+    Pair<String, String> localeTagAndUrl =
+        LocaleUtils.lookupBestLocaleTagAndManifestUrl(
+            MODEL_TYPE, Locale.forLanguageTag("en"), settings);
+
+    assertThat(localeTagAndUrl).isNull();
+  }
+
+  /** Sets the device config flag derived from the manifest URL template to the given URL. */
+  private void setUpManifestUrl(
+      @ModelType.ModelTypeDef String modelType, String localeTag, String url) {
+    deviceConfig.setConfig(
+        String.format(TextClassifierSettings.MANIFEST_URL_TEMPLATE, modelType, localeTag), url);
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
index 09593c2..351d00b 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
@@ -17,19 +17,15 @@
package com.android.textclassifier.downloader;
import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.when;
import android.content.Context;
import android.os.LocaleList;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
-import androidx.work.ExistingWorkPolicy;
-import androidx.work.OneTimeWorkRequest;
import androidx.work.WorkInfo;
import androidx.work.WorkManager;
-import androidx.work.testing.TestDriver;
import androidx.work.testing.WorkManagerTestInitHelper;
-import com.android.os.AtomsProto.TextClassifierDownloadReported;
-import com.android.textclassifier.common.ModelFileManager;
import com.android.textclassifier.common.ModelType;
import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.statsd.TextClassifierDownloadLoggerTestRule;
@@ -38,27 +34,23 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.util.List;
import java.util.Locale;
+import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
+import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
@RunWith(AndroidJUnit4.class)
public final class ModelDownloadManagerTest {
- private static final String MANIFEST_URL =
- "https://www.gstatic.com/android/text_classifier/x/v123/en.fb.manifest";
- private static final String MANIFEST_URL_2 =
- "https://www.gstatic.com/android/text_classifier/y/v456/zh.fb.manifest";
- // Parameterized test is not yet supported for instrumentation test
+ private static final String MODEL_PATH = "/data/test.model";
@ModelType.ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
- private static final TextClassifierDownloadReported.ModelType MODEL_TYPE_ATOM =
- TextClassifierDownloadReported.ModelType.ANNOTATOR;
private static final String LOCALE_TAG = "en";
- private static final String LOCALE_TAG_2 = "zh";
- private static final String LOCALE_UNIVERSAL_TAG = ModelDownloadManager.UNIVERSAL_LOCALE_TAG;
private static final LocaleList DEFAULT_LOCALE_LIST = new LocaleList(new Locale(LOCALE_TAG));
@Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
@@ -70,8 +62,8 @@
// TODO(licha): Maybe we can just use the real TextClassifierSettings
private TestingDeviceConfig deviceConfig;
private WorkManager workManager;
- private TestDriver workManagerTestDriver;
private ModelDownloadManager downloadManager;
+ @Mock DownloadedModelManager downloadedModelManager;
@Before
public void setUp() {
@@ -81,276 +73,84 @@
this.deviceConfig = new TestingDeviceConfig();
this.workManager = WorkManager.getInstance(context);
- this.workManagerTestDriver = WorkManagerTestInitHelper.getTestDriver(context);
- ModelFileManager modelFileManager = new ModelFileManager(context, ImmutableList.of());
this.downloadManager =
new ModelDownloadManager(
context,
- DownloaderTestUtils.TestWorker.class,
- modelFileManager,
+ NewModelDownloadWorker.class,
+ downloadedModelManager,
new TextClassifierSettings(deviceConfig),
MoreExecutors.newDirectExecutorService());
+
setDefaultLocalesRule.set(DEFAULT_LOCALE_LIST);
deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
}
@After
public void tearDown() {
+ workManager.cancelUniqueWork(ModelDownloadManager.UNIQUE_QUEUE_NAME);
DownloaderTestUtils.deleteRecursively(
ApplicationProvider.getApplicationContext().getFilesDir());
}
@Test
public void onTextClassifierServiceCreated_requestEnqueued() throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
downloadManager.onTextClassifierServiceCreated();
- WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
+ WorkInfo workInfo =
+ Iterables.getOnlyElement(
+ DownloaderTestUtils.queryWorkInfos(
+ workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
}
@Test
public void onLocaleChanged_requestEnqueued() throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
downloadManager.onLocaleChanged();
- WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
+ WorkInfo workInfo =
+ Iterables.getOnlyElement(
+ DownloaderTestUtils.queryWorkInfos(
+ workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
}
@Test
public void onTextClassifierDeviceConfigChanged_requestEnqueued() throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
downloadManager.onTextClassifierDeviceConfigChanged();
- WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
+ WorkInfo workInfo =
+ Iterables.getOnlyElement(
+ DownloaderTestUtils.queryWorkInfos(
+ workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
}
@Test
public void onTextClassifierDeviceConfigChanged_downloaderDisabled() throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, false);
downloadManager.onTextClassifierDeviceConfigChanged();
- WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- assertThat(workInfo).isNull();
- }
-
- @Test
- public void onTextClassifierDeviceConfigChanged_flagNotSet() throws Exception {
- downloadManager.onTextClassifierDeviceConfigChanged();
-
- WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- assertThat(workInfo).isNull();
- }
-
- @Test
- public void onTextClassifierDeviceConfigChanged_skipManifestProcessedBefore() throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
- // Simulates a previous model download task
- OneTimeWorkRequest modelDownloadRequest =
- new OneTimeWorkRequest.Builder(DownloaderTestUtils.TestWorker.class)
- .addTag(MANIFEST_URL)
- .build();
- DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.SUCCESS);
- workManager
- .enqueueUniqueWork(
- ModelDownloadManager.formatUniqueWorkName(MODEL_TYPE, LOCALE_TAG),
- ExistingWorkPolicy.REPLACE,
- modelDownloadRequest)
- .getResult()
- .get();
-
- // Assert the model download work succeeded
- WorkInfo oldWorkInfo =
- DownloaderTestUtils.queryTheOnlyWorkInfo(
- workManager, ModelDownloadManager.formatUniqueWorkName(MODEL_TYPE, LOCALE_TAG));
- assertThat(oldWorkInfo.getState()).isEqualTo(WorkInfo.State.SUCCEEDED);
-
- // Trigger the config check
- downloadManager.onTextClassifierDeviceConfigChanged();
- WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- assertThat(newWorkInfo.getId()).isEqualTo(oldWorkInfo.getId());
- }
-
- @Test
- public void onTextClassifierDeviceConfigChanged_scheduleWorkForAllModelTypes() throws Exception {
- for (@ModelType.ModelTypeDef String modelType : ModelType.values()) {
- setUpManifestUrl(modelType, LOCALE_TAG, modelType + MANIFEST_URL);
- }
- downloadManager.onTextClassifierDeviceConfigChanged();
-
- for (@ModelType.ModelTypeDef String modelType : ModelType.values()) {
- WorkInfo workInfo = queryTheOnlyWorkInfo(modelType, LOCALE_TAG);
- assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
- }
- }
-
- @Test
- public void onTextClassifierDeviceConfigChanged_checkIsIdempotent() throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
- downloadManager.onTextClassifierDeviceConfigChanged();
- WorkInfo oldWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- downloadManager.onTextClassifierDeviceConfigChanged();
- WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-
- // Will not schedule multiple times, still the same WorkInfo
- assertThat(oldWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
- assertThat(newWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
- assertThat(oldWorkInfo.getId()).isEqualTo(newWorkInfo.getId());
- assertThat(oldWorkInfo.getTags()).containsExactlyElementsIn(newWorkInfo.getTags());
- }
-
- @Test
- public void onTextClassifierDeviceConfigChanged_newWorkReplaceOldWork() throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
- downloadManager.onTextClassifierDeviceConfigChanged();
- WorkInfo oldWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL_2);
- downloadManager.onTextClassifierDeviceConfigChanged();
- WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-
- // oldWorkInfo will be replaced with the newWorkInfo
- assertThat(oldWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
- assertThat(newWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
- assertThat(oldWorkInfo.getId()).isNotEqualTo(newWorkInfo.getId());
- assertThat(oldWorkInfo.getTags()).contains(MANIFEST_URL);
- assertThat(newWorkInfo.getTags()).contains(MANIFEST_URL_2);
- }
-
- @Test
- public void onTextClassifierDeviceConfigChanged_workerSucceeded() throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
- downloadManager.onTextClassifierDeviceConfigChanged();
-
- WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-
- DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.SUCCESS);
- workManagerTestDriver.setAllConstraintsMet(workInfo.getId());
-
- WorkInfo succeededWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- assertThat(succeededWorkInfo.getId()).isEqualTo(workInfo.getId());
- assertThat(succeededWorkInfo.getState()).isEqualTo(WorkInfo.State.SUCCEEDED);
assertThat(
- succeededWorkInfo.getOutputData().getString(NewModelDownloadWorker.DATA_MODEL_TYPE_KEY))
- .isEqualTo(MODEL_TYPE);
- assertThat(
- succeededWorkInfo.getOutputData().getString(NewModelDownloadWorker.DATA_LOCALE_TAG_KEY))
- .isEqualTo(LOCALE_TAG);
- assertThat(
- succeededWorkInfo
- .getOutputData()
- .getString(NewModelDownloadWorker.DATA_MANIFEST_URL_KEY))
- .isEqualTo(MANIFEST_URL);
+ DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME))
+ .isEmpty();
}
@Test
- public void onTextClassifierDeviceConfigChanged_workerFailed() throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ public void onTextClassifierDeviceConfigChanged_newWorkDoNotReplaceOldWork() throws Exception {
downloadManager.onTextClassifierDeviceConfigChanged();
+ downloadManager.onTextClassifierDeviceConfigChanged();
+ List<WorkInfo> workInfos =
+ DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME);
- WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-
- DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.FAILURE);
- workManagerTestDriver.setAllConstraintsMet(workInfo.getId());
-
- WorkInfo failedWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- assertThat(failedWorkInfo.getId()).isEqualTo(workInfo.getId());
- assertThat(failedWorkInfo.getState()).isEqualTo(WorkInfo.State.FAILED);
- assertThat(failedWorkInfo.getOutputData().getString(NewModelDownloadWorker.DATA_MODEL_TYPE_KEY))
- .isEqualTo(MODEL_TYPE);
- assertThat(failedWorkInfo.getOutputData().getString(NewModelDownloadWorker.DATA_LOCALE_TAG_KEY))
- .isEqualTo(LOCALE_TAG);
- assertThat(
- failedWorkInfo.getOutputData().getString(NewModelDownloadWorker.DATA_MANIFEST_URL_KEY))
- .isEqualTo(MANIFEST_URL);
+ assertThat(workInfos.stream().map(WorkInfo::getState).collect(Collectors.toList()))
+ .containsExactly(WorkInfo.State.ENQUEUED, WorkInfo.State.BLOCKED);
}
@Test
- public void onTextClassifierDeviceConfigChanged_workerRetried() throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
- downloadManager.onTextClassifierDeviceConfigChanged();
+ public void listDownloadedModels() throws Exception {
+ File modelFile = new File(MODEL_PATH);
+ when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(ImmutableList.of(modelFile));
- WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-
- DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.RETRY);
- workManagerTestDriver.setAllConstraintsMet(workInfo.getId());
-
- WorkInfo retryWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- assertThat(retryWorkInfo.getId()).isEqualTo(workInfo.getId());
- assertThat(retryWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
- assertThat(retryWorkInfo.getRunAttemptCount()).isEqualTo(1);
- }
-
- @Test
- public void onTextClassifierDeviceConfigChanged_chooseTheBestLocaleTag() throws Exception {
- // System default locale: zh-hant-hk
- setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("zh-hant-hk")));
-
- // All configured locale tags
- setUpManifestUrl(MODEL_TYPE, "zh-hant", MANIFEST_URL); // best match
- setUpManifestUrl(MODEL_TYPE, "zh", MANIFEST_URL_2); // too general
- setUpManifestUrl(MODEL_TYPE, "zh-hk", MANIFEST_URL_2); // missing script
- setUpManifestUrl(MODEL_TYPE, "zh-hans-hk", MANIFEST_URL_2); // incorrect script
- setUpManifestUrl(MODEL_TYPE, "es-hant-hk", MANIFEST_URL_2); // incorrect language
-
- downloadManager.onTextClassifierDeviceConfigChanged();
-
- // The downloader choose: zh-hant
- assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh-hant").getState())
- .isEqualTo(WorkInfo.State.ENQUEUED);
-
- assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh")).isNull();
- assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh-hk")).isNull();
- assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh-hans-hk")).isNull();
- assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "es-hant-hk")).isNull();
- }
-
- @Test
- public void onTextClassifierDeviceConfigChanged_useUniversalModelIfNoMatchedTag()
- throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG_2, MANIFEST_URL);
- setUpManifestUrl(MODEL_TYPE, LOCALE_UNIVERSAL_TAG, MANIFEST_URL_2);
- downloadManager.onTextClassifierDeviceConfigChanged();
-
- assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG_2)).isNull();
-
- WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_UNIVERSAL_TAG);
- assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
- assertThat(workInfo.getTags()).contains(MANIFEST_URL_2);
- }
-
- @Test
- public void onTextClassifierDeviceConfigChanged_logAfterDownloadScheduled() throws Exception {
- setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
- downloadManager.onTextClassifierDeviceConfigChanged();
-
- WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
- assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-
- // verify log
- TextClassifierDownloadReported atom = Iterables.getOnlyElement(loggerTestRule.getLoggedAtoms());
- assertThat(atom.getDownloadStatus())
- .isEqualTo(TextClassifierDownloadReported.DownloadStatus.SCHEDULED);
- assertThat(atom.getModelType()).isEqualTo(MODEL_TYPE_ATOM);
- assertThat(atom.getUrlSuffix()).isEqualTo(MANIFEST_URL);
- }
-
- private void setUpManifestUrl(
- @ModelType.ModelTypeDef String modelType, String localeTag, String url) {
- String deviceConfigFlag =
- String.format(TextClassifierSettings.MANIFEST_URL_TEMPLATE, modelType, localeTag);
- deviceConfig.setConfig(deviceConfigFlag, url);
- }
-
- /** One unique queue holds at most one request at one time. Returns null if no WorkInfo found. */
- private WorkInfo queryTheOnlyWorkInfo(@ModelType.ModelTypeDef String modelType, String localeTag)
- throws Exception {
- return DownloaderTestUtils.queryTheOnlyWorkInfo(
- workManager, ModelDownloadManager.formatUniqueWorkName(modelType, localeTag));
+ assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).containsExactly(modelFile);
}
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderImplTest.java
index 8c6d13f..47e7fb6 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderImplTest.java
@@ -24,10 +24,10 @@
import androidx.test.core.app.ApplicationProvider;
import com.android.textclassifier.downloader.TestModelDownloaderService.DownloadResult;
import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
import java.io.File;
import java.nio.file.Files;
import java.util.concurrent.CancellationException;
-import java.util.concurrent.Executors;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -53,18 +53,23 @@
ModelManifest.newBuilder().addModels(MODEL_PROTO).build();
private ModelDownloaderImpl modelDownloaderImpl;
+ private File modelDownloaderDir;
@Before
public void setUp() {
Context context = ApplicationProvider.getApplicationContext();
this.modelDownloaderImpl =
new ModelDownloaderImpl(
- context, Executors.newSingleThreadExecutor(), TestModelDownloaderService.class);
+ context, MoreExecutors.newDirectExecutorService(), TestModelDownloaderService.class);
+ this.modelDownloaderDir = new File(context.getFilesDir(), "downloader");
+ this.modelDownloaderDir.mkdirs();
+
+ TestModelDownloaderService.reset();
}
@After
public void tearDown() {
- TestModelDownloaderService.reset();
+ DownloaderTestUtils.deleteRecursively(modelDownloaderDir);
}
@Test
@@ -148,9 +153,12 @@
assertThat(TestModelDownloaderService.isBound()).isFalse();
TestModelDownloaderService.setBindSucceed(true);
- TestModelDownloaderService.setDownloadResult(MANIFEST_URL, DownloadResult.BLOCKING, null);
+ TestModelDownloaderService.setDownloadResult(MANIFEST_URL, DownloadResult.DO_NOTHING, null);
ListenableFuture<ModelManifest> manifestFuture =
modelDownloaderImpl.downloadManifest(MANIFEST_URL);
+
+ assertThat(TestModelDownloaderService.getOnBindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isTrue();
manifestFuture.cancel(true);
expectThrows(CancellationException.class, manifestFuture::get);
@@ -165,7 +173,8 @@
assertThat(TestModelDownloaderService.isBound()).isFalse();
TestModelDownloaderService.setBindSucceed(false);
- ListenableFuture<File> modelFuture = modelDownloaderImpl.downloadModel(MODEL_PROTO);
+ ListenableFuture<File> modelFuture =
+ modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
Throwable t = expectThrows(Throwable.class, modelFuture::get);
assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
@@ -184,9 +193,11 @@
TestModelDownloaderService.setBindSucceed(true);
TestModelDownloaderService.setDownloadResult(
MODEL_URL, DownloadResult.SUCCEEDED, MODEL_CONTENT_BYTES);
- ListenableFuture<File> modelFuture = modelDownloaderImpl.downloadModel(MODEL_PROTO);
+ ListenableFuture<File> modelFuture =
+ modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
File modelFile = modelFuture.get();
+ assertThat(modelFile.getParentFile()).isEqualTo(modelDownloaderDir);
assertThat(Files.readAllBytes(modelFile.toPath())).isEqualTo(MODEL_CONTENT_BYTES);
assertThat(TestModelDownloaderService.getOnUnbindInvokedLatch().await(1L, SECONDS)).isTrue();
assertThat(TestModelDownloaderService.isBound()).isFalse();
@@ -200,7 +211,8 @@
TestModelDownloaderService.setBindSucceed(true);
TestModelDownloaderService.setDownloadResult(MODEL_URL, DownloadResult.FAILED, null);
- ListenableFuture<File> modelFuture = modelDownloaderImpl.downloadModel(MODEL_PROTO);
+ ListenableFuture<File> modelFuture =
+ modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
Throwable t = expectThrows(Throwable.class, modelFuture::get);
assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
@@ -220,7 +232,8 @@
TestModelDownloaderService.setBindSucceed(true);
TestModelDownloaderService.setDownloadResult(
MODEL_URL, DownloadResult.SUCCEEDED, "randomString".getBytes());
- ListenableFuture<File> modelFuture = modelDownloaderImpl.downloadModel(MODEL_PROTO);
+ ListenableFuture<File> modelFuture =
+ modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
Throwable t = expectThrows(Throwable.class, modelFuture::get);
assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
@@ -237,8 +250,12 @@
assertThat(TestModelDownloaderService.isBound()).isFalse();
TestModelDownloaderService.setBindSucceed(true);
- TestModelDownloaderService.setDownloadResult(MODEL_URL, DownloadResult.BLOCKING, null);
- ListenableFuture<File> modelFuture = modelDownloaderImpl.downloadModel(MODEL_PROTO);
+ TestModelDownloaderService.setDownloadResult(MODEL_URL, DownloadResult.DO_NOTHING, null);
+ ListenableFuture<File> modelFuture =
+ modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
+
+ assertThat(TestModelDownloaderService.getOnBindInvokedLatch().await(1L, SECONDS)).isTrue();
+ assertThat(TestModelDownloaderService.isBound()).isTrue();
modelFuture.cancel(true);
expectThrows(CancellationException.class, modelFuture::get);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/NewModelDownloadWorkerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/NewModelDownloadWorkerTest.java
index 34b6660..4a2741b 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/NewModelDownloadWorkerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/NewModelDownloadWorkerTest.java
@@ -17,10 +17,13 @@
package com.android.textclassifier.downloader;
import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import android.content.Context;
+import android.os.LocaleList;
+import androidx.room.Room;
import androidx.test.core.app.ApplicationProvider;
import androidx.work.ListenableWorker;
import androidx.work.WorkerFactory;
@@ -30,11 +33,17 @@
import com.android.os.AtomsProto.TextClassifierDownloadReported.DownloadStatus;
import com.android.os.AtomsProto.TextClassifierDownloadReported.FailureReason;
import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.TextClassifierSettings;
import com.android.textclassifier.common.statsd.TextClassifierDownloadLoggerTestRule;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
+import com.android.textclassifier.testing.TestingDeviceConfig;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.MoreExecutors;
import java.io.File;
+import java.util.List;
+import java.util.Locale;
+import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
@@ -47,35 +56,54 @@
@RunWith(JUnit4.class)
public final class NewModelDownloadWorkerTest {
private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+ private static final String MODEL_TYPE_2 = ModelType.ACTIONS_SUGGESTIONS;
private static final TextClassifierDownloadReported.ModelType MODEL_TYPE_ATOM =
TextClassifierDownloadReported.ModelType.ANNOTATOR;
private static final String LOCALE_TAG = "en";
private static final String MANIFEST_URL =
"https://www.gstatic.com/android/text_classifier/q/v711/en.fb.manifest";
+ private static final String MANIFEST_URL_2 =
+ "https://www.gstatic.com/android/text_classifier/q/v711/zh.fb.manifest";
private static final String MODEL_URL =
"https://www.gstatic.com/android/text_classifier/q/v711/en.fb";
+ private static final String MODEL_URL_2 =
+ "https://www.gstatic.com/android/text_classifier/q/v711/zh.fb";
private static final int RUN_ATTEMPT_COUNT = 1;
- private static final int MAX_RUN_ATTEMPT_COUNT = 5;
+ private static final int WORKER_MAX_RUN_ATTEMPT_COUNT = 5;
+ private static final int MANIFEST_MAX_ATTEMPT_COUNT = 2;
private static final ModelManifest.Model MODEL_PROTO =
ModelManifest.Model.newBuilder()
.setUrl(MODEL_URL)
.setSizeInBytes(1)
.setFingerprint("fingerprint")
.build();
+ private static final ModelManifest.Model MODEL_PROTO_2 =
+ ModelManifest.Model.newBuilder()
+ .setUrl(MODEL_URL_2)
+ .setSizeInBytes(1)
+ .setFingerprint("fingerprint")
+ .build();
private static final ModelManifest MODEL_MANIFEST_PROTO =
ModelManifest.newBuilder().addModels(MODEL_PROTO).build();
+ private static final ModelManifest MODEL_MANIFEST_PROTO_2 =
+ ModelManifest.newBuilder().addModels(MODEL_PROTO_2).build();
private static final ModelDownloadException FAILED_TO_DOWNLOAD_EXCEPTION =
new ModelDownloadException(
ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER, "failed to download");
private static final FailureReason FAILED_TO_DOWNLOAD_FAILURE_REASON =
TextClassifierDownloadReported.FailureReason.FAILED_TO_DOWNLOAD_OTHER;
-
- private File downloadedModelDir;
- private File pendingModelDir;
- private File targetModelFile;
+ private static final LocaleList DEFAULT_LOCALE_LIST = new LocaleList(new Locale(LOCALE_TAG));
@Mock private ModelDownloader modelDownloader;
- @Mock private Runnable postDownloadCleanUpRunnable;
+ private File modelDownloaderDir;
+ private File modelFile;
+ private File modelFile2;
+ private DownloadedModelDatabase db;
+ private DownloadedModelManager downloadedModelManager;
+ private TestingDeviceConfig deviceConfig;
+ private TextClassifierSettings settings;
+
+ @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
@Rule
public final TextClassifierDownloadLoggerTestRule loggerTestRule =
@@ -86,133 +114,219 @@
MockitoAnnotations.initMocks(this);
Context context = ApplicationProvider.getApplicationContext();
- this.downloadedModelDir = new File(context.getCacheDir(), "downloaded");
- this.downloadedModelDir.mkdirs();
- this.pendingModelDir = new File(context.getCacheDir(), "pending");
- this.pendingModelDir.mkdirs();
- this.targetModelFile =
- new File(
- downloadedModelDir,
- NewModelDownloadWorker.formatFileNameByModelTypeAndUrl(MODEL_TYPE, MODEL_URL));
- this.targetModelFile.delete();
+ this.deviceConfig = new TestingDeviceConfig();
+ this.settings = new TextClassifierSettings(deviceConfig);
+ this.modelDownloaderDir = new File(context.getCacheDir(), "downloaded");
+ this.modelDownloaderDir.mkdirs();
+ this.modelFile = new File(modelDownloaderDir, "test.model");
+ this.modelFile2 = new File(modelDownloaderDir, "test2.model");
+ this.db = Room.inMemoryDatabaseBuilder(context, DownloadedModelDatabase.class).build();
+ this.downloadedModelManager =
+ DownloadedModelManagerImpl.getInstanceForTesting(db, modelDownloaderDir, settings);
+
+ setDefaultLocalesRule.set(DEFAULT_LOCALE_LIST);
+ deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
}
@After
public void cleanUp() {
- DownloaderTestUtils.deleteRecursively(downloadedModelDir);
- DownloaderTestUtils.deleteRecursively(pendingModelDir);
+ db.close();
+ DownloaderTestUtils.deleteRecursively(modelDownloaderDir);
}
@Test
- public void downloadManifest_succeed_downloadModel_succeed_moveModelFile_succeed()
- throws Exception {
+ public void downloadSucceed() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
when(modelDownloader.downloadManifest(MANIFEST_URL))
.thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
- File pendingModelFile = new File(pendingModelDir, "pending.model.file");
- pendingModelFile.createNewFile();
- when(modelDownloader.downloadModel(MODEL_PROTO))
- .thenReturn(Futures.immediateFuture(pendingModelFile));
+ modelFile.createNewFile();
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
- assertThat(targetModelFile.exists()).isTrue();
- assertThat(pendingModelFile.exists()).isFalse();
- verify(postDownloadCleanUpRunnable).run();
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
verifyLoggedAtom(DownloadStatus.SUCCEEDED, RUN_ATTEMPT_COUNT, /* failureReason= */ null);
}
@Test
- public void downloadManifest_failed() throws Exception {
+ public void downloadSucceed_modelAlreadyExists() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ modelFile.createNewFile();
+ downloadedModelManager.registerModel(MODEL_URL, modelFile.getAbsolutePath());
+
+ NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+ verifyLoggedAtom(DownloadStatus.SUCCEEDED, RUN_ATTEMPT_COUNT, /* failureReason= */ null);
+ }
+
+ @Test
+ public void downloadSucceed_manifestAlreadyExists() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ modelFile.createNewFile();
+ downloadedModelManager.registerModel(MODEL_URL, modelFile.getAbsolutePath());
+ downloadedModelManager.registerManifest(MANIFEST_URL, MODEL_URL);
+
+ NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+ verifyLoggedAtom(DownloadStatus.SUCCEEDED, RUN_ATTEMPT_COUNT, /* failureReason= */ null);
+ }
+
+ @Test
+ public void downloadSucceed_downloadMultipleModels() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ setUpManifestUrl(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL_2);
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ when(modelDownloader.downloadManifest(MANIFEST_URL_2))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO_2));
+ modelFile.createNewFile();
+ modelFile2.createNewFile();
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO_2))
+ .thenReturn(Futures.immediateFuture(modelFile2));
+
+ NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(modelFile2.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE_2)).containsExactly(modelFile2);
+ List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedAtoms();
+ assertThat(atoms).hasSize(2);
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getUrlSuffix)
+ .collect(Collectors.toList()))
+ .containsExactly(MANIFEST_URL, MANIFEST_URL_2);
+ assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ }
+
+ @Test
+ public void downloadSucceed_shareSingleModelDownloadForMultipleManifest() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ setUpManifestUrl(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL_2);
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ when(modelDownloader.downloadManifest(MANIFEST_URL_2))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ modelFile.createNewFile();
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+
+ NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE_2)).containsExactly(modelFile);
+ verify(modelDownloader, times(1)).downloadModel(modelDownloaderDir, MODEL_PROTO);
+ List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedAtoms();
+ assertThat(atoms).hasSize(2);
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getUrlSuffix)
+ .collect(Collectors.toList()))
+ .containsExactly(MANIFEST_URL, MANIFEST_URL_2);
+ assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ }
+
+ @Test
+ public void downloadSucceed_shareManifest() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ setUpManifestUrl(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL);
+ when(modelDownloader.downloadManifest(MANIFEST_URL))
+ .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+ modelFile.createNewFile();
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+ .thenReturn(Futures.immediateFuture(modelFile));
+
+ NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(modelFile.exists()).isTrue();
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+ assertThat(downloadedModelManager.listModels(MODEL_TYPE_2)).containsExactly(modelFile);
+ verify(modelDownloader, times(1)).downloadManifest(MANIFEST_URL);
+ List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedAtoms();
+ assertThat(atoms).hasSize(2);
+ assertThat(
+ atoms.stream()
+ .map(TextClassifierDownloadReported::getUrlSuffix)
+ .collect(Collectors.toList()))
+ .containsExactly(MANIFEST_URL, MANIFEST_URL);
+ assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+ }
+
+ @Test
+ public void downloadFailed_failedToDownloadManifest() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
when(modelDownloader.downloadManifest(MANIFEST_URL))
.thenReturn(Futures.immediateFailedFuture(FAILED_TO_DOWNLOAD_EXCEPTION));
NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
- assertThat(targetModelFile.exists()).isFalse();
- verify(postDownloadCleanUpRunnable).run();
verifyLoggedAtom(
DownloadStatus.FAILED_AND_RETRY, RUN_ATTEMPT_COUNT, FAILED_TO_DOWNLOAD_FAILURE_REASON);
}
@Test
- public void downloadManifest_succeed_downloadModel_failed() throws Exception {
+ public void downloadFailed_failedToDownloadModel() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
when(modelDownloader.downloadManifest(MANIFEST_URL))
.thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
- when(modelDownloader.downloadModel(MODEL_PROTO))
+ when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
.thenReturn(Futures.immediateFailedFuture(FAILED_TO_DOWNLOAD_EXCEPTION));
NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
- assertThat(targetModelFile.exists()).isFalse();
- verify(postDownloadCleanUpRunnable).run();
verifyLoggedAtom(
DownloadStatus.FAILED_AND_RETRY, RUN_ATTEMPT_COUNT, FAILED_TO_DOWNLOAD_FAILURE_REASON);
}
@Test
- public void downloadManifest_succeed_downloadModel_alreadyExist() throws Exception {
- when(modelDownloader.downloadManifest(MANIFEST_URL))
- .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
- when(modelDownloader.downloadModel(MODEL_PROTO))
- .thenReturn(Futures.immediateFailedFuture(FAILED_TO_DOWNLOAD_EXCEPTION));
- targetModelFile.createNewFile();
-
- NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
- assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
- assertThat(targetModelFile.exists()).isTrue();
- verify(postDownloadCleanUpRunnable).run();
- verifyLoggedAtom(DownloadStatus.SUCCEEDED, RUN_ATTEMPT_COUNT, /* failureReason= */ null);
- }
-
- @Test
- public void downloadManifest_succeed_downloadModel_succeed_moveModelFile_failed()
- throws Exception {
- when(modelDownloader.downloadManifest(MANIFEST_URL))
- .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
- File pendingModelFile = new File(pendingModelDir, "pending.model.file");
- pendingModelFile.createNewFile();
- when(modelDownloader.downloadModel(MODEL_PROTO))
- .thenReturn(Futures.immediateFuture(pendingModelFile));
-
- try {
- downloadedModelDir.setWritable(false);
-
- NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
- assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
- assertThat(targetModelFile.exists()).isFalse();
- assertThat(pendingModelFile.exists()).isFalse();
- verify(postDownloadCleanUpRunnable).run();
- verifyLoggedAtom(
- DownloadStatus.FAILED_AND_RETRY, RUN_ATTEMPT_COUNT, FailureReason.FAILED_TO_MOVE_MODEL);
- } finally {
- downloadedModelDir.setWritable(true);
- }
- }
-
- @Test
- public void reachMaxRunAttempts() throws Exception {
- NewModelDownloadWorker worker = createWorker(MAX_RUN_ATTEMPT_COUNT);
+ public void downloadFailed_reachWorkerMaxRunAttempts() throws Exception {
+ NewModelDownloadWorker worker = createWorker(WORKER_MAX_RUN_ATTEMPT_COUNT);
assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.failure());
- verifyLoggedAtom(
- DownloadStatus.FAILED_AND_ABORT,
- MAX_RUN_ATTEMPT_COUNT,
- FailureReason.UNKNOWN_FAILURE_REASON);
}
@Test
- public void workerStopped() throws Exception {
- NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
- worker.onStopped();
- verifyLoggedAtom(
- DownloadStatus.FAILED_AND_RETRY, RUN_ATTEMPT_COUNT, FailureReason.WORKER_STOPPED);
+ public void downloadSkipped_reachManifestMaxAttempts() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ // The current default max attempts is 2, so two registered failures reach the limit.
+ downloadedModelManager.registerManifestDownloadFailure(MANIFEST_URL);
+ downloadedModelManager.registerManifestDownloadFailure(MANIFEST_URL);
+
+ NewModelDownloadWorker worker = createWorker(MANIFEST_MAX_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(loggerTestRule.getLoggedAtoms()).isEmpty();
+ }
+
+ @Test
+ public void downloadSkipped_manifestAlreadyProcessed() throws Exception {
+ setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+ modelFile.createNewFile();
+ downloadedModelManager.registerModel(MODEL_URL, modelFile.getAbsolutePath());
+ downloadedModelManager.registerManifest(MANIFEST_URL, MODEL_URL);
+ downloadedModelManager.registerManifestEnrollment(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+
+ NewModelDownloadWorker worker = createWorker(MANIFEST_MAX_ATTEMPT_COUNT);
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(loggerTestRule.getLoggedAtoms()).isEmpty();
}
private NewModelDownloadWorker createWorker(int runAttemptCount) {
return TestListenableWorkerBuilder.from(
ApplicationProvider.getApplicationContext(), NewModelDownloadWorker.class)
- .setInputData(
- NewModelDownloadWorker.createInputData(
- MODEL_TYPE, LOCALE_TAG, MANIFEST_URL, MAX_RUN_ATTEMPT_COUNT))
.setRunAttemptCount(runAttemptCount)
.setWorkerFactory(
new WorkerFactory() {
@@ -224,8 +338,8 @@
workerParameters,
MoreExecutors.newDirectExecutorService(),
modelDownloader,
- downloadedModelDir,
- postDownloadCleanUpRunnable);
+ downloadedModelManager,
+ settings);
}
})
.build();
@@ -243,4 +357,11 @@
assertThat(atom.getFailureReason()).isEqualTo(failureReason);
}
}
+
+ private void setUpManifestUrl(
+ @ModelType.ModelTypeDef String modelType, String localeTag, String url) {
+ String deviceConfigFlag =
+ String.format(TextClassifierSettings.MANIFEST_URL_TEMPLATE, modelType, localeTag);
+ deviceConfig.setConfig(deviceConfigFlag, url);
+ }
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/TestModelDownloaderService.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/TestModelDownloaderService.java
index 172fa68..97a55fb 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/TestModelDownloaderService.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/TestModelDownloaderService.java
@@ -25,6 +25,7 @@
import java.util.concurrent.CountDownLatch;
import javax.annotation.Nullable;
+// TODO(licha): Find another way to test the service. This static state can break easily.
/** Test Service of IModelDownloaderService. */
public final class TestModelDownloaderService extends Service {
private static final String TAG = "TestModelDownloaderService";
@@ -38,13 +39,13 @@
public enum DownloadResult {
SUCCEEDED,
FAILED,
- BLOCKING,
DO_NOTHING
}
// Obviously this does not work when considering concurrency, but probably fine for test purpose
private static boolean boundBefore = false;
private static boolean boundNow = false;
+ private static CountDownLatch onBindInvokedLatch = new CountDownLatch(1);
private static CountDownLatch onUnbindInvokedLatch = new CountDownLatch(1);
private static boolean bindSucceed = false;
@@ -60,6 +61,10 @@
return boundNow;
}
+ public static CountDownLatch getOnBindInvokedLatch() {
+ return onBindInvokedLatch;
+ }
+
public static CountDownLatch getOnUnbindInvokedLatch() {
return onUnbindInvokedLatch;
}
@@ -78,26 +83,34 @@
public static void reset() {
boundBefore = false;
boundNow = false;
+ onBindInvokedLatch = new CountDownLatch(1);
onUnbindInvokedLatch = new CountDownLatch(1);
bindSucceed = false;
}
@Override
public IBinder onBind(Intent intent) {
- if (bindSucceed) {
- boundBefore = true;
- boundNow = true;
- return new TestModelDownloaderServiceImpl();
- } else {
- return null;
+ try {
+ if (bindSucceed) {
+ boundBefore = true;
+ boundNow = true;
+ return new TestModelDownloaderServiceImpl();
+ } else {
+ return null;
+ }
+ } finally {
+ onBindInvokedLatch.countDown();
}
}
@Override
public boolean onUnbind(Intent intent) {
- boundNow = false;
- onUnbindInvokedLatch.countDown();
- return false;
+ try {
+ boundNow = false;
+ return false;
+ } finally {
+ onUnbindInvokedLatch.countDown();
+ }
}
private static final class TestModelDownloaderServiceImpl extends IModelDownloaderService.Stub {
@@ -118,10 +131,6 @@
case FAILED:
callback.onFailure(ERROR_CODE, ERROR_MSG);
break;
- case BLOCKING:
- Thread.sleep(1000 * 60L);
- TcLog.w(TAG, "Blocking request returns.");
- break;
case DO_NOTHING:
// Do nothing
}