Export TextClassifier to sc-mainline-prod.

Use Merged-In to skip merging to master. Will resolve the conflict later.

Bug: 187927611
Merged-In: Ic08373aecd85a1c71814e7cc4798b54dc6281bdf
Change-Id: Icb51a375bea23d9767e0241bf50b1178f4382f0f
diff --git a/java/Android.bp b/java/Android.bp
index 30fd2bc..5948a17 100644
--- a/java/Android.bp
+++ b/java/Android.bp
@@ -57,6 +57,10 @@
         "src/**/*.aidl",
     ],
     manifest: "LibNoManifest_AndroidManifest.xml",
+    plugins: [
+        "auto_value_plugin",
+        "androidx.room_room-compiler-plugin",
+    ],
     static_libs: [
         "androidx.core_core",
         "libtextclassifier-java",
@@ -69,6 +73,8 @@
         "textclassifier-statsd",
         "textclassifier-java-proto-lite",
         "androidx.concurrent_concurrent-futures",
+        "auto_value_annotations",
+        "androidx.room_room-runtime",
     ],
     sdk_version: "system_current",
     min_sdk_version: "30",
diff --git a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
index beb155b..4838503 100644
--- a/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
+++ b/java/src/com/android/textclassifier/ActionsSuggestionsHelper.java
@@ -27,7 +27,7 @@
 import android.view.textclassifier.ConversationAction;
 import android.view.textclassifier.ConversationActions;
 import android.view.textclassifier.ConversationActions.Message;
-import com.android.textclassifier.common.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.ModelFile;
 import com.android.textclassifier.common.base.TcLog;
 import com.android.textclassifier.common.intent.LabeledIntent;
 import com.android.textclassifier.common.intent.TemplateIntentFactory;
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index 707af6a..49acf2e 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -30,7 +30,6 @@
 import android.view.textclassifier.TextSelection;
 import androidx.annotation.NonNull;
 import androidx.collection.LruCache;
-import com.android.textclassifier.common.ModelFileManager;
 import com.android.textclassifier.common.TextClassifierServiceExecutors;
 import com.android.textclassifier.common.TextClassifierSettings;
 import com.android.textclassifier.common.base.TcLog;
@@ -88,10 +87,10 @@
     modelDownloadManager =
         new com.android.textclassifier.downloader.ModelDownloadManager(
             injector.getContext().getApplicationContext(),
-            modelFileManager,
             settings,
             TextClassifierServiceExecutors.getDownloaderExecutor());
     modelDownloadManager.onTextClassifierServiceCreated();
+    modelFileManager.addModelDownloaderModels(modelDownloadManager, settings);
     textClassifierApiUsageLogger =
         injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor);
   }
@@ -209,6 +208,7 @@
     IndentingPrintWriter indentingPrintWriter = new IndentingPrintWriter(writer);
     // TODO(licha): Also dump ModelDownloadManager for debugging
     textClassifier.dump(indentingPrintWriter);
+    modelDownloadManager.dump(indentingPrintWriter);
     dumpImpl(indentingPrintWriter);
     indentingPrintWriter.flush();
   }
diff --git a/java/src/com/android/textclassifier/common/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
similarity index 86%
rename from java/src/com/android/textclassifier/common/ModelFileManager.java
rename to java/src/com/android/textclassifier/ModelFileManager.java
index 406a889..155cf56 100644
--- a/java/src/com/android/textclassifier/common/ModelFileManager.java
+++ b/java/src/com/android/textclassifier/ModelFileManager.java
@@ -14,19 +14,22 @@
  * limitations under the License.
  */
 
-package com.android.textclassifier.common;
+package com.android.textclassifier;
 
 import android.content.Context;
 import android.content.res.AssetFileDescriptor;
 import android.content.res.AssetManager;
 import android.os.LocaleList;
 import android.os.ParcelFileDescriptor;
-import android.util.ArraySet;
 import androidx.annotation.GuardedBy;
 import androidx.collection.ArrayMap;
+import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.common.ModelType;
 import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
 import com.android.textclassifier.common.base.TcLog;
 import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.downloader.ModelDownloadManager;
 import com.android.textclassifier.utils.IndentingPrintWriter;
 import com.google.android.textclassifier.ActionsSuggestionsModel;
 import com.google.android.textclassifier.AnnotatorModel;
@@ -59,27 +62,19 @@
 
   private static final String TAG = "ModelFileManager";
 
-  private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models/";
   private static final File CONFIG_UPDATER_DIR = new File("/data/misc/textclassifier/");
   private static final String ASSETS_DIR = "textclassifier";
 
-  private final List<ModelFileLister> modelFileListers;
-  private final File modelDownloaderDir;
+  private ImmutableList<ModelFileLister> modelFileListers;
 
   public ModelFileManager(Context context, TextClassifierSettings settings) {
     Preconditions.checkNotNull(context);
     Preconditions.checkNotNull(settings);
 
     AssetManager assetManager = context.getAssets();
-    this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
     modelFileListers =
         ImmutableList.of(
             // Annotator models.
-            new RegularFilePatternMatchLister(
-                ModelType.ANNOTATOR,
-                this.modelDownloaderDir,
-                "annotator\\.(.*)\\.model",
-                settings::isModelDownloadManagerEnabled),
             new RegularFileFullMatchLister(
                 ModelType.ANNOTATOR,
                 new File(CONFIG_UPDATER_DIR, "textclassifier.model"),
@@ -91,11 +86,6 @@
                 "annotator\\.(.*)\\.model",
                 /* isEnabled= */ () -> true),
             // Actions models.
-            new RegularFilePatternMatchLister(
-                ModelType.ACTIONS_SUGGESTIONS,
-                this.modelDownloaderDir,
-                "actions_suggestions\\.(.*)\\.model",
-                settings::isModelDownloadManagerEnabled),
             new RegularFileFullMatchLister(
                 ModelType.ACTIONS_SUGGESTIONS,
                 new File(CONFIG_UPDATER_DIR, "actions_suggestions.model"),
@@ -107,11 +97,6 @@
                 "actions_suggestions\\.(.*)\\.model",
                 /* isEnabled= */ () -> true),
             // LangID models.
-            new RegularFilePatternMatchLister(
-                ModelType.LANG_ID,
-                this.modelDownloaderDir,
-                "lang_id\\.(.*)\\.model",
-                settings::isModelDownloadManagerEnabled),
             new RegularFileFullMatchLister(
                 ModelType.LANG_ID,
                 new File(CONFIG_UPDATER_DIR, "lang_id.model"),
@@ -126,10 +111,37 @@
 
   @VisibleForTesting
   public ModelFileManager(Context context, List<ModelFileLister> modelFileListers) {
-    this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
     this.modelFileListers = ImmutableList.copyOf(modelFileListers);
   }
 
+  // TODO(licha): Move this to constructor and consider using DownloadedModelManager here
+  /** Enables ModelFileManager to scan and use models downloaded by the model downloader. */
+  public void addModelDownloaderModels(
+      ModelDownloadManager modelDownloadManager, TextClassifierSettings settings) {
+    this.modelFileListers =
+        ImmutableList.<ModelFileLister>builder()
+            .addAll(modelFileListers)
+            .add(
+                modelType -> {
+                  ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+                  if (settings.isModelDownloadManagerEnabled()) {
+                    for (File modelFile : modelDownloadManager.listDownloadedModels(modelType)) {
+                      try {
+                        // TODO(licha): Make the ModelFile class public and construct downloader
+                        // model files with locale tag in our internal database
+                        modelFilesBuilder.add(
+                            ModelFile.createFromRegularFile(modelFile, modelType));
+                      } catch (IOException e) {
+                        TcLog.e(
+                            TAG, "Failed to create ModelFile: " + modelFile.getAbsolutePath(), e);
+                      }
+                    }
+                  }
+                  return modelFilesBuilder.build();
+                })
+            .build();
+  }
+
   /**
    * Returns an immutable list of model files listed by the given model files supplier.
    *
@@ -146,6 +158,7 @@
   }
 
   /** Lists model files. */
+  @FunctionalInterface
   public interface ModelFileLister {
     List<ModelFile> list(@ModelTypeDef String modelType);
   }
@@ -328,16 +341,36 @@
   @Nullable
   public ModelFile findBestModelFile(
       @ModelTypeDef String modelType, @Nullable LocaleList localePreferences) {
-    final String languages =
-        localePreferences == null || localePreferences.isEmpty()
-            ? LocaleList.getDefault().toLanguageTags()
-            : localePreferences.toLanguageTags();
-    final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
+    // LocaleList#get returns null when out of range, so an empty list must use the default too.
+    boolean noPrefs = localePreferences == null || localePreferences.isEmpty();
+    return findBestModelFile(modelType, noPrefs ? Locale.getDefault() : localePreferences.get(0));
+  }
 
+  /**
+   * Returns the best model file for the given locale, {@code null} if nothing is found.
+   *
+   * @param modelType the type of model to look up (e.g. annotator, lang_id, etc.)
+   * @param targetLocale the preferred locale
+   */
+  @Nullable
+  private ModelFile findBestModelFile(@ModelTypeDef String modelType, Locale targetLocale) {
+    List<Locale.LanguageRange> deviceLanguageRanges =
+        Locale.LanguageRange.parse(LocaleList.getDefault().toLanguageTags());
+    boolean languageIndependentModelOnly = false;
+    if (Locale.lookupTag(deviceLanguageRanges, ImmutableList.of(targetLocale.getLanguage()))
+        == null) {
+      // If the targetLocale's language is not in device locale list, we don't match it to avoid
+      // leaking user language profile to the callers.
+      languageIndependentModelOnly = true;
+    }
+    List<Locale.LanguageRange> targetLanguageRanges =
+        Locale.LanguageRange.parse(targetLocale.toLanguageTag());
     ModelFile bestModel = null;
     for (ModelFile model : listModelFiles(modelType)) {
-      // TODO(licha): update this when we want to support multiple languages
-      if (model.isAnyLanguageSupported(languageRangeList)) {
+      if (languageIndependentModelOnly && !model.languageIndependent) {
+        continue;
+      }
+      if (model.isAnyLanguageSupported(targetLanguageRanges)) {
         if (model.isPreferredTo(bestModel)) {
           bestModel = model;
         }
@@ -347,39 +380,6 @@
   }
 
   /**
-   * Deletes model files that are not preferred for any locales in user's preference.
-   *
-   * <p>This method will be invoked as a clean-up after we download a new model successfully. Race
-   * conditions are hard to avoid because we do not hold locks for files. But it should rarely cause
-   * any issues since it's safe to delete a model file in use (b/c we mmap it to memory).
-   */
-  public void deleteUnusedModelFiles() {
-    TcLog.d(TAG, "Start to delete unused model files.");
-    LocaleList localeList = LocaleList.getDefault();
-    for (@ModelTypeDef String modelType : ModelType.values()) {
-      ArraySet<ModelFile> allModelFiles = new ArraySet<>(listModelFiles(modelType));
-      for (int i = 0; i < localeList.size(); i++) {
-        // If a model file is preferred for any local in locale list, then keep it
-        ModelFile bestModel = findBestModelFile(modelType, new LocaleList(localeList.get(i)));
-        allModelFiles.remove(bestModel);
-      }
-      for (ModelFile modelFile : allModelFiles) {
-        if (modelFile.canWrite()) {
-          TcLog.d(TAG, "Deleting model: " + modelFile);
-          if (!modelFile.delete()) {
-            TcLog.w(TAG, "Failed to delete model: " + modelFile);
-          }
-        }
-      }
-    }
-  }
-
-  /** Returns the directory containing models downloaded by the downloader. */
-  public File getModelDownloaderDir() {
-    return modelDownloaderDir;
-  }
-
-  /**
    * Dumps the internal state for debugging.
    *
    * @param printWriter writer to write dumped states
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index bf326fb..3c28466 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -44,8 +44,7 @@
 import androidx.annotation.GuardedBy;
 import androidx.annotation.WorkerThread;
 import androidx.core.util.Pair;
-import com.android.textclassifier.common.ModelFileManager;
-import com.android.textclassifier.common.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.ModelFile;
 import com.android.textclassifier.common.ModelType;
 import com.android.textclassifier.common.TextClassifierSettings;
 import com.android.textclassifier.common.base.TcLog;
diff --git a/java/src/com/android/textclassifier/common/TextClassifierSettings.java b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
index fdf259e..b6ec1c2 100644
--- a/java/src/com/android/textclassifier/common/TextClassifierSettings.java
+++ b/java/src/com/android/textclassifier/common/TextClassifierSettings.java
@@ -20,6 +20,7 @@
 
 import android.provider.DeviceConfig;
 import android.provider.DeviceConfig.Properties;
+import android.text.TextUtils;
 import android.view.textclassifier.ConversationAction;
 import android.view.textclassifier.TextClassifier;
 import androidx.annotation.NonNull;
@@ -117,13 +118,19 @@
       "manifest_download_required_network_type";
   /** Max attempts allowed for a single ModelDownloader downloading task. */
   @VisibleForTesting
-  static final String MODEL_DOWNLOAD_MAX_ATTEMPTS = "model_download_max_attempts";
+  static final String MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS = "model_download_worker_max_attempts";
+  /** Max attempts allowed for a certain manifest url. */
+  @VisibleForTesting
+  static final String MANIFEST_DOWNLOAD_MAX_ATTEMPTS = "manifest_download_max_attempts";
 
   @VisibleForTesting
   static final String MODEL_DOWNLOAD_BACKOFF_DELAY_IN_MILLIS =
       "model_download_backoff_delay_in_millis";
   /** Flag name for manifest url is dynamically formatted based on model type and model language. */
   @VisibleForTesting public static final String MANIFEST_URL_TEMPLATE = "manifest_url_%s_%s";
+
+  @VisibleForTesting public static final String MODEL_URL_BLOCKLIST = "model_url_blocklist";
+  @VisibleForTesting public static final String MODEL_URL_BLOCKLIST_SEPARATOR = ",";
   /** Sampling rate for TextClassifier API logging. */
   static final String TEXTCLASSIFIER_API_LOG_SAMPLE_RATE = "textclassifier_api_log_sample_rate";
 
@@ -195,7 +202,8 @@
   private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
   private static final boolean MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT = false;
   private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT = "UNMETERED";
-  private static final int MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT = 5;
+  private static final int MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS_DEFAULT = 5;
+  private static final int MANIFEST_DOWNLOAD_MAX_ATTEMPTS_DEFAULT = 2;
   private static final long MODEL_DOWNLOAD_BACKOFF_DELAY_IN_MILLIS_DEFAULT = HOURS.toMillis(1);
   private static final String MANIFEST_URL_DEFAULT = "";
   private static final float[] LANG_ID_CONTEXT_SETTINGS_DEFAULT = new float[] {20f, 1.0f, 0.4f};
@@ -380,9 +388,14 @@
         MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT);
   }
 
-  public int getModelDownloadMaxAttempts() {
+  public int getModelDownloadWorkerMaxAttempts() {
     return deviceConfig.getInt(
-        NAMESPACE, MODEL_DOWNLOAD_MAX_ATTEMPTS, MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT);
+        NAMESPACE, MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS, MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS_DEFAULT);
+  }
+
+  public int getManifestDownloadMaxAttempts() {
+    return deviceConfig.getInt(
+        NAMESPACE, MANIFEST_DOWNLOAD_MAX_ATTEMPTS, MANIFEST_DOWNLOAD_MAX_ATTEMPTS_DEFAULT);
   }
 
   public long getModelDownloadBackoffDelayInMillis() {
@@ -406,12 +419,22 @@
     return deviceConfig.getString(NAMESPACE, urlFlagName, MANIFEST_URL_DEFAULT);
   }
 
+  /** Gets the model urls to avoid (for quick rollbacks); blank entries in the flag are skipped. */
+  public ImmutableList<String> getModelUrlBlocklist() {
+    return ImmutableList.copyOf(
+        Splitter.on(MODEL_URL_BLOCKLIST_SEPARATOR).trimResults().omitEmptyStrings()
+            .split(deviceConfig.getString(NAMESPACE, MODEL_URL_BLOCKLIST, "")));
+  }
+
+  // TODO(licha): Let this method return a <localeTag, flagValue> map.
   /**
    * Gets all language variants configured for a specific ModelType.
    *
    * <p>For a specific language, there can be many variants: de-CH, de-LI, zh-Hans, zh-Hant. There
    * is no easy way to hardcode the list in client. Therefore, we parse all configured flag's name
    * in DeviceConfig, and let the client to choose the best variant to download.
+   *
+   * <p>If one flag's value is empty, it will be ignored.
    */
   public ImmutableList<String> getLanguageTagsForManifestURL(
       @ModelType.ModelTypeDef String modelType) {
@@ -419,8 +442,11 @@
     Properties properties = deviceConfig.getProperties(NAMESPACE);
     ImmutableList.Builder<String> variantsBuilder = ImmutableList.builder();
     for (String name : properties.getKeyset()) {
-      if (name.startsWith(urlFlagBaseName)
-          && properties.getString(name, /* defaultValue= */ null) != null) {
+      if (!name.startsWith(urlFlagBaseName)) {
+        continue;
+      }
+      String value = properties.getString(name, /* defaultValue= */ null);
+      if (!TextUtils.isEmpty(value)) {
         variantsBuilder.add(name.substring(urlFlagBaseName.length()));
       }
     }
@@ -458,7 +484,8 @@
     pw.printPair(TEMPLATE_INTENT_FACTORY_ENABLED, isTemplateIntentFactoryEnabled());
     pw.printPair(TRANSLATE_IN_CLASSIFICATION_ENABLED, isTranslateInClassificationEnabled());
     pw.printPair(MODEL_DOWNLOAD_MANAGER_ENABLED, isModelDownloadManagerEnabled());
-    pw.printPair(MODEL_DOWNLOAD_MAX_ATTEMPTS, getModelDownloadMaxAttempts());
+    pw.printPair(MODEL_DOWNLOAD_WORKER_MAX_ATTEMPTS, getModelDownloadWorkerMaxAttempts());
+    pw.printPair(MANIFEST_DOWNLOAD_MAX_ATTEMPTS, getManifestDownloadMaxAttempts());
     pw.decreaseIndent();
     pw.printPair(TEXTCLASSIFIER_API_LOG_SAMPLE_RATE, getTextClassifierApiLogSampleRate());
     pw.printPair(SESSION_ID_TO_CONTEXT_CACHE_SIZE, getSessionIdToContextCacheSize());
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelDatabase.java b/java/src/com/android/textclassifier/downloader/DownloadedModelDatabase.java
new file mode 100644
index 0000000..0614c2f
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelDatabase.java
@@ -0,0 +1,381 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import androidx.annotation.IntDef;
+import androidx.annotation.NonNull;
+import androidx.room.ColumnInfo;
+import androidx.room.Dao;
+import androidx.room.Database;
+import androidx.room.DatabaseView;
+import androidx.room.Delete;
+import androidx.room.Embedded;
+import androidx.room.Entity;
+import androidx.room.ForeignKey;
+import androidx.room.Index;
+import androidx.room.Insert;
+import androidx.room.OnConflictStrategy;
+import androidx.room.Query;
+import androidx.room.RoomDatabase;
+import androidx.room.Transaction;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.auto.value.AutoValue;
+import com.google.common.collect.Iterables;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.util.List;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+
+/** Database storing info about models downloaded by model downloader */
+@Database(
+    entities = {
+      DownloadedModelDatabase.Model.class,
+      DownloadedModelDatabase.Manifest.class,
+      DownloadedModelDatabase.ManifestModelCrossRef.class,
+      DownloadedModelDatabase.ManifestEnrollment.class
+    },
+    views = {DownloadedModelDatabase.ModelView.class},
+    version = 1,
+    exportSchema = true)
+abstract class DownloadedModelDatabase extends RoomDatabase {
+  public static final String TAG = "DownloadedModelDatabase";
+
+  /** Represents a downloaded model file. */
+  @AutoValue
+  @Entity(
+      tableName = "model",
+      primaryKeys = {"model_url"})
+  abstract static class Model {
+    @AutoValue.CopyAnnotations
+    @ColumnInfo(name = "model_url")
+    @NonNull
+    public abstract String getModelUrl();
+
+    @AutoValue.CopyAnnotations
+    @ColumnInfo(name = "model_path")
+    @NonNull
+    public abstract String getModelPath();
+
+    public static Model create(String modelUrl, String modelPath) {
+      return new AutoValue_DownloadedModelDatabase_Model(modelUrl, modelPath);
+    }
+  }
+
+  /** Represents a manifest we processed. */
+  @AutoValue
+  @Entity(
+      tableName = "manifest",
+      primaryKeys = {"manifest_url"})
+  abstract static class Manifest {
+    // TODO(licha): Consider using Enum here
+    @Retention(RetentionPolicy.SOURCE)
+    @IntDef({STATUS_UNKNOWN, STATUS_FAILED, STATUS_SUCCEEDED})
+    @interface StatusDef {}
+
+    public static final int STATUS_UNKNOWN = 0;
+    /** Failed to download this manifest. Could be retried in the future. */
+    public static final int STATUS_FAILED = 1;
+    /** Downloaded this manifest successfully and it's currently in storage. */
+    public static final int STATUS_SUCCEEDED = 2;
+
+    @AutoValue.CopyAnnotations
+    @ColumnInfo(name = "manifest_url")
+    @NonNull
+    public abstract String getManifestUrl();
+
+    @AutoValue.CopyAnnotations
+    @ColumnInfo(name = "status")
+    @StatusDef
+    public abstract int getStatus();
+
+    @AutoValue.CopyAnnotations
+    @ColumnInfo(name = "failure_counts")
+    public abstract int getFailureCounts();
+
+    public static Manifest create(String manifestUrl, @StatusDef int status, int failureCounts) {
+      return new AutoValue_DownloadedModelDatabase_Manifest(manifestUrl, status, failureCounts);
+    }
+  }
+
+  /**
+   * Represents the relationship between manifests and downloaded models.
+   *
+   * <p>A manifest can include multiple models, a model can also be included in multiple manifests.
+   * In different manifests, a model may have different configurations (e.g. primary model in
+   * manifest A but dark model in B).
+   */
+  @AutoValue
+  @Entity(
+      tableName = "manifest_model_cross_ref",
+      primaryKeys = {"manifest_url", "model_url"},
+      foreignKeys = {
+        @ForeignKey(
+            entity = Manifest.class,
+            parentColumns = "manifest_url",
+            childColumns = "manifest_url",
+            onDelete = ForeignKey.CASCADE),
+        @ForeignKey(
+            entity = Model.class,
+            parentColumns = "model_url",
+            childColumns = "model_url",
+            onDelete = ForeignKey.CASCADE),
+      },
+      indices = {
+        @Index(value = {"manifest_url"}),
+        @Index(value = {"model_url"}),
+      })
+  abstract static class ManifestModelCrossRef {
+    @AutoValue.CopyAnnotations
+    @ColumnInfo(name = "manifest_url")
+    @NonNull
+    public abstract String getManifestUrl();
+
+    @AutoValue.CopyAnnotations
+    @ColumnInfo(name = "model_url")
+    @NonNull
+    public abstract String getModelUrl();
+
+    public static ManifestModelCrossRef create(String manifestUrl, String modelUrl) {
+      return new AutoValue_DownloadedModelDatabase_ManifestModelCrossRef(manifestUrl, modelUrl);
+    }
+  }
+
+  /**
+   * Represents the relationship between user scenarios and manifests.
+   *
+   * <p>For each unique user scenario (i.e. modelType + localeTag), we store the manifest we should
+   * use. The same manifest can be used for different scenarios.
+   */
+  @AutoValue
+  @Entity(
+      tableName = "manifest_enrollment",
+      primaryKeys = {"model_type", "locale_tag"},
+      foreignKeys = {
+        @ForeignKey(
+            entity = Manifest.class,
+            parentColumns = "manifest_url",
+            childColumns = "manifest_url",
+            onDelete = ForeignKey.CASCADE)
+      },
+      indices = {@Index(value = {"manifest_url"})})
+  abstract static class ManifestEnrollment {
+    @AutoValue.CopyAnnotations
+    @ColumnInfo(name = "model_type")
+    @NonNull
+    @ModelTypeDef
+    public abstract String getModelType();
+
+    @AutoValue.CopyAnnotations
+    @ColumnInfo(name = "locale_tag")
+    @NonNull
+    public abstract String getLocaleTag();
+
+    @AutoValue.CopyAnnotations
+    @ColumnInfo(name = "manifest_url")
+    @NonNull
+    public abstract String getManifestUrl();
+
+    public static ManifestEnrollment create(
+        @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+      return new AutoValue_DownloadedModelDatabase_ManifestEnrollment(
+          modelType, localeTag, manifestUrl);
+    }
+  }
+
+  /** Represents the mapping from manifest enrollments to models. */
+  @AutoValue
+  @DatabaseView(
+      value =
+          "SELECT manifest_enrollment.*, model.* "
+              + "FROM manifest_enrollment "
+              + "INNER JOIN manifest_model_cross_ref "
+              + "ON manifest_enrollment.manifest_url = manifest_model_cross_ref.manifest_url "
+              + "INNER JOIN model "
+              + "ON manifest_model_cross_ref.model_url = model.model_url",
+      viewName = "model_view")
+  abstract static class ModelView {
+    @AutoValue.CopyAnnotations
+    @Embedded
+    @NonNull
+    public abstract ManifestEnrollment getManifestEnrollment();
+
+    @AutoValue.CopyAnnotations
+    @Embedded
+    @NonNull
+    public abstract Model getModel();
+
+    public static ModelView create(ManifestEnrollment manifestEnrollment, Model model) {
+      return new AutoValue_DownloadedModelDatabase_ModelView(manifestEnrollment, model);
+    }
+  }
+
+  @Dao
+  abstract static class DownloadedModelDatabaseDao {
+    // Full table scan
+    @Query("SELECT * FROM model")
+    abstract List<Model> queryAllModels();
+
+    @Query("SELECT * FROM manifest")
+    abstract List<Manifest> queryAllManifests();
+
+    @Query("SELECT * FROM manifest_model_cross_ref")
+    abstract List<ManifestModelCrossRef> queryAllManifestModelCrossRefs();
+
+    @Query("SELECT * FROM manifest_enrollment")
+    abstract List<ManifestEnrollment> queryAllManifestEnrollments();
+
+    @Query("SELECT * FROM model_view")
+    abstract List<ModelView> queryAllModelViews();
+
+    // Single table query
+    @Query("SELECT * FROM model WHERE model_url = :modelUrl")
+    abstract List<Model> queryModelWithModelUrl(String modelUrl);
+
+    @Query("SELECT * FROM manifest WHERE manifest_url = :manifestUrl")
+    abstract List<Manifest> queryManifestWithManifestUrl(String manifestUrl);
+
+    @Query(
+        "SELECT * FROM manifest_enrollment WHERE model_type = :modelType "
+            + "AND locale_tag = :localeTag")
+    abstract List<ManifestEnrollment> queryManifestEnrollmentWithModelTypeAndLocaleTag(
+        @ModelTypeDef String modelType, String localeTag);
+
+    // Helpers for clean up
+    @Query(
+        "SELECT manifest.* FROM manifest "
+            + "LEFT JOIN model_view "
+            + "ON manifest.manifest_url = model_view.manifest_url "
+            + "WHERE model_view.manifest_url IS NULL "
+            + "AND manifest.status = 2")
+    abstract List<Manifest> queryUnusedManifests();
+
+    @Query(
+        "SELECT * FROM manifest WHERE manifest.status = 1 "
+            + "AND manifest.manifest_url NOT IN (:manifestUrlsToKeep)")
+    abstract List<Manifest> queryUnusedManifestFailureRecords(List<String> manifestUrlsToKeep);
+
+    @Query(
+        "SELECT model.* FROM model LEFT JOIN model_view "
+            + "ON model.model_url = model_view.model_url "
+            + "WHERE model_view.model_url IS NULL")
+    abstract List<Model> queryUnusedModels();
+
+    // Insertion
+    @Insert(onConflict = OnConflictStrategy.REPLACE)
+    abstract void insert(Model model);
+
+    @Insert(onConflict = OnConflictStrategy.REPLACE)
+    abstract void insert(Manifest manifest);
+
+    @Insert(onConflict = OnConflictStrategy.REPLACE)
+    abstract void insert(ManifestModelCrossRef manifestModelCrossRef);
+
+    @Insert(onConflict = OnConflictStrategy.REPLACE)
+    abstract void insert(ManifestEnrollment manifestEnrollment);
+
+    @Transaction
+    void insertManifestAndModelCrossRef(String manifestUrl, String modelUrl) {
+      insert(Manifest.create(manifestUrl, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+      insert(ManifestModelCrossRef.create(manifestUrl, modelUrl));
+    }
+
+    @Transaction
+    void increaseManifestFailureCounts(String manifestUrl) {
+      List<Manifest> manifests = queryManifestWithManifestUrl(manifestUrl);
+      if (manifests.isEmpty()) {
+        insert(Manifest.create(manifestUrl, Manifest.STATUS_FAILED, /* failureCounts= */ 1));
+      } else {
+        Manifest prevManifest = Iterables.getOnlyElement(manifests);
+        insert(
+            Manifest.create(
+                manifestUrl, Manifest.STATUS_FAILED, prevManifest.getFailureCounts() + 1));
+      }
+    }
+
+    // Deletion
+    @Delete
+    abstract void deleteModels(List<Model> models);
+
+    @Delete
+    abstract void deleteManifests(List<Manifest> manifests);
+
+    @Delete
+    abstract void deleteManifestModelCrossRefs(List<ManifestModelCrossRef> manifestModelCrossRefs);
+
+    @Delete
+    abstract void deleteManifestEnrollments(List<ManifestEnrollment> manifestEnrollments);
+
+    @Transaction
+    void deleteUnusedManifestsAndModels() {
+      // Because Manifest table is the parent table of ManifestModelCrossRef table, related cross
+      // ref row in that table will be deleted automatically
+      deleteManifests(queryUnusedManifests());
+      deleteModels(queryUnusedModels());
+    }
+
+    @Transaction
+    void deleteUnusedManifestFailureRecords(List<String> manifestUrlsToKeep) {
+      deleteManifests(queryUnusedManifestFailureRecords(manifestUrlsToKeep));
+    }
+  }
+
+  abstract DownloadedModelDatabaseDao dao();
+
+  /** Dumps every table for debugging; each query is submitted to {@code executorService}. */
+  void dump(IndentingPrintWriter printWriter, ExecutorService executorService) {
+    printWriter.println("DownloadedModelDatabase");
+    printWriter.increaseIndent();
+    try {
+      printWriter.println("Model Table:");
+      printWriter.increaseIndent();
+      for (Model model : executorService.submit(() -> dao().queryAllModels()).get()) {
+        printWriter.println(model.toString());
+      }
+      printWriter.decreaseIndent();
+      printWriter.println("Manifest Table:");
+      printWriter.increaseIndent();
+      for (Manifest manifest : executorService.submit(() -> dao().queryAllManifests()).get()) {
+        printWriter.println(manifest.toString());
+      }
+      printWriter.decreaseIndent();
+      printWriter.println("ManifestModelCrossRef Table:");
+      printWriter.increaseIndent();
+      for (ManifestModelCrossRef crossRef :
+          executorService.submit(() -> dao().queryAllManifestModelCrossRefs()).get()) {
+        printWriter.println(crossRef.toString());
+      }
+      printWriter.decreaseIndent();
+      printWriter.println("ManifestEnrollment Table:");
+      printWriter.increaseIndent();
+      for (ManifestEnrollment enrollment :
+          executorService.submit(() -> dao().queryAllManifestEnrollments()).get()) {
+        printWriter.println(enrollment.toString());
+      }
+      printWriter.decreaseIndent();
+    } catch (ExecutionException e) {
+      TcLog.e(TAG, "Failed to dump the database", e);
+    } catch (InterruptedException e) {
+      // Restore the interrupt status so code further up the stack can still observe it.
+      Thread.currentThread().interrupt();
+      TcLog.e(TAG, "Failed to dump the database", e);
+    }
+    printWriter.decreaseIndent();
+  }
+}
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManager.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManager.java
new file mode 100644
index 0000000..f6400fb
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManager.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import android.util.Pair;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import java.io.File;
+import java.util.List;
+import java.util.Map;
+import javax.annotation.Nullable;
+
+// TODO(licha): Let Worker access DB class directly, then we can make this a lister interface
+/** An interface to provide easy access to DownloadedModelDatabase. */
+interface DownloadedModelManager {
+
+  /** Returns the directory containing models downloaded by the downloader. */
+  File getModelDownloaderDir();
+
+  /**
+   * Returns all downloaded model files for the given modelType.
+   *
+   * <p>This method should return quickly as it may be on the critical path of serving requests.
+   *
+   * @param modelType the type of the model
+   * @return the model files; empty if none is suitable (NOTE: declared {@code @Nullable})
+   */
+  @Nullable
+  List<File> listModels(@ModelTypeDef String modelType);
+
+  /**
+   * Returns the model entry if the model represented by the url is in our database.
+   *
+   * @param modelUrl the model url
+   * @return model entry from the internal database, null if it does not exist
+   */
+  @Nullable
+  Model getModel(String modelUrl);
+
+  /**
+   * Returns the manifest entry if the manifest represented by the url is in our database.
+   *
+   * @param manifestUrl the manifest url
+   * @return manifest entry from the internal database, null if it does not exist
+   */
+  @Nullable
+  Manifest getManifest(String manifestUrl);
+
+  /**
+   * Returns the manifest enrollment entry if a manifest is registered for the given type and
+   * locale.
+   *
+   * @param modelType the model type of the enrollment
+   * @param localeTag the locale tag of the enrollment
+   * @return manifest enrollment entry from the internal database, null if it does not exist
+   */
+  @Nullable
+  ManifestEnrollment getManifestEnrollment(@ModelTypeDef String modelType, String localeTag);
+
+  /**
+   * Adds a newly downloaded model to the internal database.
+   *
+   * <p>The model must be linked to a manifest via {@link #registerManifest}. Otherwise it will be
+   * cleaned up automatically later.
+   *
+   * @param modelUrl the url the model was downloaded from
+   * @param modelPath the path where the downloaded model is stored
+   */
+  void registerModel(String modelUrl, String modelPath);
+
+  /**
+   * Adds a newly downloaded manifest to the internal database.
+   *
+   * <p>The manifest must be linked to a specific use case via {@link #registerManifestEnrollment}.
+   * Otherwise it will be cleaned up automatically later. Currently there is only one model in one
+   * manifest.
+   *
+   * @param manifestUrl the url the manifest was downloaded from
+   * @param modelUrl the url of the only model inside the manifest
+   */
+  void registerManifest(String manifestUrl, String modelUrl);
+
+  /**
+   * Adds a failure record for the given manifest url.
+   *
+   * <p>If the manifest failed before, then increase the prevFailureCounts by one. We skip a
+   * manifest if it failed too many times before.
+   *
+   * @param manifestUrl the failed manifest url
+   */
+  void registerManifestDownloadFailure(String manifestUrl);
+
+  /**
+   * Links a manifest to a specific (modelType, localeTag) use case.
+   *
+   * <p>After this registration, we will start to use this model file for all requests for the given
+   * locale and the specified model type.
+   *
+   * @param modelType the model type
+   * @param localeTag the tag of the locale on user's device that this manifest should be used for
+   * @param manifestUrl the url of the manifest
+   */
+  void registerManifestEnrollment(
+      @ModelTypeDef String modelType, String localeTag, String manifestUrl);
+
+  /**
+   * Cleans up unused downloaded models and updates other internal states.
+   *
+   * @param bestLocaleTagAndManifestUrls {@code <modelType, <localeTag, manifestUrl>>} the worker
+   *     tried to download
+   */
+  void onDownloadCompleted(Map<String, Pair<String, String>> bestLocaleTagAndManifestUrls);
+
+  /**
+   * Dumps the internal state for debugging.
+   *
+   * @param printWriter writer to write dumped states
+   */
+  void dump(IndentingPrintWriter printWriter);
+}
diff --git a/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
new file mode 100644
index 0000000..78559fd
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/DownloadedModelManagerImpl.java
@@ -0,0 +1,279 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import android.content.Context;
+import android.util.ArrayMap;
+import android.util.Pair;
+import androidx.annotation.GuardedBy;
+import androidx.room.Room;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierServiceExecutors;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ModelView;
+import com.android.textclassifier.utils.IndentingPrintWriter;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+
+/** A singleton {@link DownloadedModelManager} backed by a Room database and an in-memory cache. */
+public final class DownloadedModelManagerImpl implements DownloadedModelManager {
+  private static final String TAG = "DownloadedModelManagerImpl";
+  private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models";
+  private static final String DOWNLOADED_MODEL_DATABASE_NAME = "tcs-downloaded-model-db";
+
+  private static final Object staticLock = new Object();
+
+  @GuardedBy("staticLock")
+  private static DownloadedModelManagerImpl instance;
+
+  private final File modelDownloaderDir;
+  private final DownloadedModelDatabase db;
+  private final TextClassifierSettings settings;
+
+  private final Object cacheLock = new Object();
+
+  // model type -> models enrolled for that type; every known type always has a (maybe empty) list
+  @GuardedBy("cacheLock")
+  private final ArrayMap<String, List<Model>> modelLookupCache;
+
+  @GuardedBy("cacheLock")
+  private boolean cacheInitialized; // Stays false (the field default) until updateCache() runs.
+
+  /** Returns the process-wide instance, creating it (and its Room database) on first use. */
+  public static DownloadedModelManager getInstance(Context context) {
+    synchronized (staticLock) {
+      if (instance == null) {
+        DownloadedModelDatabase db =
+            Room.databaseBuilder(
+                    context, DownloadedModelDatabase.class, DOWNLOADED_MODEL_DATABASE_NAME)
+                .build();
+        File modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+        instance =
+            new DownloadedModelManagerImpl(db, modelDownloaderDir, new TextClassifierSettings());
+      }
+      return instance;
+    }
+  }
+
+  /** Builds a fresh, non-singleton instance around the given collaborators, for tests. */
+  @VisibleForTesting
+  static DownloadedModelManagerImpl getInstanceForTesting(
+      DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) {
+    return new DownloadedModelManagerImpl(db, modelDownloaderDir, settings);
+  }
+
+  private DownloadedModelManagerImpl(
+      DownloadedModelDatabase db, File modelDownloaderDir, TextClassifierSettings settings) {
+    this.db = db;
+    this.modelDownloaderDir = modelDownloaderDir;
+    this.settings = settings;
+    this.modelLookupCache = new ArrayMap<>();
+    for (String modelType : ModelType.values()) {
+      this.modelLookupCache.put(modelType, new ArrayList<>());
+    }
+  }
+
+  @Override
+  public File getModelDownloaderDir() {
+    if (!modelDownloaderDir.exists() && !modelDownloaderDir.mkdirs()) {
+      TcLog.e(TAG, "Failed to create model download dir: " + modelDownloaderDir.getAbsolutePath());
+    }
+    return modelDownloaderDir;
+  }
+
+  @Override
+  @Nullable
+  public ImmutableList<File> listModels(@ModelTypeDef String modelType) {
+    synchronized (cacheLock) {
+      if (!cacheInitialized) {
+        updateCache();
+      }
+      ImmutableList.Builder<File> builder = ImmutableList.builder();
+      ImmutableList<String> blockedModels = settings.getModelUrlBlocklist();
+      for (Model model : modelLookupCache.get(modelType)) {
+        if (blockedModels.contains(model.getModelUrl())) {
+          TcLog.d(TAG, "Model is blocklisted: " + model);
+          continue;
+        }
+        builder.add(new File(model.getModelPath()));
+      }
+      return builder.build();
+    }
+  }
+
+  @Override
+  @Nullable
+  public Model getModel(String modelUrl) {
+    List<Model> models = db.dao().queryModelWithModelUrl(modelUrl);
+    return Iterables.getFirst(models, null);
+  }
+
+  @Override
+  @Nullable
+  public Manifest getManifest(String manifestUrl) {
+    List<Manifest> manifests = db.dao().queryManifestWithManifestUrl(manifestUrl);
+    return Iterables.getFirst(manifests, null);
+  }
+
+  @Override
+  @Nullable
+  public ManifestEnrollment getManifestEnrollment(
+      @ModelTypeDef String modelType, String localeTag) {
+    List<ManifestEnrollment> manifestEnrollments =
+        db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(modelType, localeTag);
+    return Iterables.getFirst(manifestEnrollments, null);
+  }
+
+  @Override
+  public void registerModel(String modelUrl, String modelPath) {
+    db.dao().insert(Model.create(modelUrl, modelPath));
+  }
+
+  @Override
+  public void registerManifest(String manifestUrl, String modelUrl) {
+    db.dao().insertManifestAndModelCrossRef(manifestUrl, modelUrl);
+  }
+
+  @Override
+  public void registerManifestDownloadFailure(String manifestUrl) {
+    db.dao().increaseManifestFailureCounts(manifestUrl);
+  }
+
+  @Override
+  public void registerManifestEnrollment(
+      @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+    db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl));
+  }
+
+  @Override
+  public void dump(IndentingPrintWriter printWriter) {
+    printWriter.println("DownloadedModelManagerImpl:");
+    printWriter.increaseIndent();
+    db.dump(printWriter, TextClassifierServiceExecutors.getDownloaderExecutor());
+    printWriter.println("ModelLookupCache:");
+    synchronized (cacheLock) {
+      for (Map.Entry<String, List<Model>> entry : modelLookupCache.entrySet()) {
+        printWriter.println(entry.getKey());
+        printWriter.increaseIndent();
+        for (Model model : entry.getValue()) {
+          printWriter.println(model.toString());
+        }
+        printWriter.decreaseIndent();
+      }
+    }
+    printWriter.decreaseIndent();
+  }
+
+  @Override
+  public void onDownloadCompleted(
+      Map<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrls) {
+    TcLog.v(TAG, "Start to clean up models and update model lookup cache...");
+    // Step 1: Clean up ManifestEnrollment table
+    List<ManifestEnrollment> allManifestEnrollments = db.dao().queryAllManifestEnrollments();
+    List<ManifestEnrollment> manifestEnrollmentsToDelete = new ArrayList<>();
+    for (String modelType : ModelType.values()) {
+      List<ManifestEnrollment> manifestEnrollmentsByType =
+          allManifestEnrollments.stream()
+              .filter(modelEnrollment -> modelEnrollment.getModelType().equals(modelType))
+              .collect(Collectors.toList());
+      Pair<String, String> localeTagAndManifestUrl =
+          modelTypeToLocaleTagAndManifestUrls.get(modelType);
+      if (localeTagAndManifestUrl == null) {
+        // No suitable manifest configured for this model type. Delete everything.
+        manifestEnrollmentsToDelete.addAll(manifestEnrollmentsByType);
+      } else {
+        String localeTag = localeTagAndManifestUrl.first;
+        String manifestUrl = localeTagAndManifestUrl.second;
+        Optional<ManifestEnrollment> optionalManifestEnrollment =
+            manifestEnrollmentsByType.stream()
+                .filter(
+                    manifestEnrollment ->
+                        manifestEnrollment.getLocaleTag().equals(localeTag)
+                            && manifestEnrollment.getManifestUrl().equals(manifestUrl))
+                .findAny();
+        if (optionalManifestEnrollment.isPresent()) {
+          // The desired manifest is downloaded successfully. Delete everything else.
+          manifestEnrollmentsToDelete.addAll(manifestEnrollmentsByType);
+          manifestEnrollmentsToDelete.remove(optionalManifestEnrollment.get());
+        } else {
+          // The desired manifest failed to be downloaded. Do not delete anything.
+          TcLog.w(
+              TAG,
+              String.format(
+                  "Desired manifest is missing on download completed: %s, %s, %s",
+                  modelType, localeTag, manifestUrl));
+        }
+      }
+    }
+    db.dao().deleteManifestEnrollments(manifestEnrollmentsToDelete);
+    // Step 2: Clean up Manifests and Models that are not linked to any ManifestEnrollment
+    db.dao().deleteUnusedManifestsAndModels();
+    // Step 3: Clean up Manifest failure records. A failure record is kept only while the worker
+    // still tries to download that manifest. Deletion is restricted to failure records because a
+    // manifest url missing from allAttemptedManifestUrls can still be useful (e.g. the current
+    // manifest is v901 and we failed to download v902: v901 is not in the map but must be kept).
+    List<String> allAttemptedManifestUrls =
+        modelTypeToLocaleTagAndManifestUrls.values().stream()
+            .map(localTagAndManifestUrlPair -> localTagAndManifestUrlPair.second)
+            .collect(Collectors.toList());
+    db.dao().deleteUnusedManifestFailureRecords(allAttemptedManifestUrls);
+    // Step 4: Update lookup cache
+    updateCache();
+    // Step 5: Clean up unused model files. listFiles() returns null if the dir cannot be listed.
+    Set<String> modelPathsToKeep =
+        db.dao().queryAllModels().stream().map(Model::getModelPath).collect(Collectors.toSet());
+    File[] modelFiles = getModelDownloaderDir().listFiles();
+    for (File modelFile : modelFiles == null ? new File[0] : modelFiles) {
+      if (!modelPathsToKeep.contains(modelFile.getAbsolutePath())) {
+        TcLog.d(TAG, "Delete model file: " + modelFile.getAbsolutePath());
+        if (!modelFile.delete()) {
+          TcLog.e(TAG, "Failed to delete model file: " + modelFile.getAbsolutePath());
+        }
+      }
+    }
+  }
+
+  // Clears every per-type bucket and rebuilds the cache from the ModelView rows.
+  private void updateCache() {
+    synchronized (cacheLock) {
+      TcLog.v(TAG, "Updating model lookup cache...");
+      for (String modelType : ModelType.values()) {
+        modelLookupCache.get(modelType).clear();
+      }
+      for (ModelView modelView : db.dao().queryAllModelViews()) {
+        modelLookupCache
+            .get(modelView.getManifestEnrollment().getModelType())
+            .add(modelView.getModel());
+      }
+      cacheInitialized = true;
+    }
+  }
+}
diff --git a/java/src/com/android/textclassifier/downloader/LocaleUtils.java b/java/src/com/android/textclassifier/downloader/LocaleUtils.java
new file mode 100644
index 0000000..a6791a3
--- /dev/null
+++ b/java/src/com/android/textclassifier/downloader/LocaleUtils.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import android.text.TextUtils;
+import android.util.Pair;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.google.common.annotations.VisibleForTesting;
+import java.util.Collection;
+import java.util.List;
+import java.util.Locale;
+import javax.annotation.Nullable;
+
+/** Static helpers for matching a target locale against the configured locale tags. */
+final class LocaleUtils {
+  @VisibleForTesting static final String UNIVERSAL_LOCALE_TAG = "universal";
+
+  /**
+   * Finds the best locale tag, and the manifest url configured for it, from device config.
+   *
+   * @param modelType the model type
+   * @param targetLocale the locale to find a match for
+   * @param settings TextClassifierSettings used to read device config
+   * @return a pair of {@code <bestLocaleTag, manifestUrl>}, or null if either one is unavailable
+   */
+  @Nullable
+  static Pair<String, String> lookupBestLocaleTagAndManifestUrl(
+      @ModelTypeDef String modelType, Locale targetLocale, TextClassifierSettings settings) {
+    // Pick the tag first; without one there is no url to look up.
+    List<String> candidateTags = settings.getLanguageTagsForManifestURL(modelType);
+    String chosenTag = lookupBestLocaleTag(targetLocale, candidateTags);
+    if (chosenTag == null) {
+      return null;
+    }
+    String configuredUrl = settings.getManifestURL(modelType, chosenTag);
+    // An empty (or null) url for the chosen tag is reported the same way as "no tag found".
+    return TextUtils.isEmpty(configuredUrl) ? null : Pair.create(chosenTag, configuredUrl);
+  }
+
+  /** Finds the best locale tag for the target locale. Returns null if none is suitable. */
+  @Nullable
+  static String lookupBestLocaleTag(Locale targetLocale, Collection<String> availableTags) {
+    // Notice: this lookup API just tries to match the longest prefix for the target locale tag.
+    // Its implementation looks inefficient and the behavior may not be 100% desired. E.g. if the
+    // target locale is en, and we only have en-uk in available tags, the current API returns null.
+    List<Locale.LanguageRange> targetRanges =
+        Locale.LanguageRange.parse(targetLocale.toLanguageTag());
+    String matchedTag = Locale.lookupTag(targetRanges, availableTags);
+    if (matchedTag == null && availableTags.contains(UNIVERSAL_LOCALE_TAG)) {
+      // Nothing locale-specific matched, but the catch-all entry is available.
+      return UNIVERSAL_LOCALE_TAG;
+    }
+    // Either a real match, or null when not even the universal tag is offered.
+    return matchedTag;
+  }
+
+  private LocaleUtils() {}
+}
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
index 2629c20..78d26fa 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloadManager.java
@@ -22,50 +22,39 @@
 import android.content.Context;
 import android.content.Intent;
 import android.content.IntentFilter;
-import android.os.LocaleList;
 import android.provider.DeviceConfig;
-import android.text.TextUtils;
 import androidx.work.BackoffPolicy;
 import androidx.work.Constraints;
 import androidx.work.ExistingWorkPolicy;
 import androidx.work.ListenableWorker;
 import androidx.work.NetworkType;
 import androidx.work.OneTimeWorkRequest;
-import androidx.work.WorkInfo;
+import androidx.work.Operation;
 import androidx.work.WorkManager;
-import androidx.work.WorkQuery;
-import com.android.textclassifier.common.ModelFileManager;
-import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
 import com.android.textclassifier.common.TextClassifierSettings;
 import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.utils.IndentingPrintWriter;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Enums;
 import com.google.common.base.Preconditions;
-import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.ListeningExecutorService;
-import java.time.Duration;
+import java.io.File;
 import java.util.List;
-import java.util.Locale;
-import java.util.concurrent.ExecutionException;
+import javax.annotation.Nullable;
 
 /** Manager to listen to config update and download latest models. */
 public final class ModelDownloadManager {
   private static final String TAG = "ModelDownloadManager";
 
-  static final String UNIVERSAL_LOCALE_TAG = "universal";
-
-  public static final int CHECK_REASON_TCS_ON_CREATE = 0;
-  public static final int CHECK_REASON_DEVICE_CONFIG_UPDATED = 1;
-  public static final int CHECK_REASON_LOCALE_UPDATED = 2;
-
-  // Keep the records forever (100 Years). We use the record to skip previously failed tasks.
-  static final long DAYS_TO_KEEP_THE_DOWNLOAD_RESULT = 365000L;
-
-  private final Object lock = new Object();
+  @VisibleForTesting static final String UNIQUE_QUEUE_NAME = "ModelDownloadWorkManagerQueue";
 
   private final Context appContext;
   private final Class<? extends ListenableWorker> modelDownloadWorkerClass;
-  private final ModelFileManager modelFileManager;
+  private final DownloadedModelManager downloadedModelManager;
   private final TextClassifierSettings settings;
   private final ListeningExecutorService executorService;
   private final DeviceConfig.OnPropertiesChangedListener deviceConfigListener;
@@ -75,28 +64,31 @@
    * Constructor for ModelDownloadManager.
    *
    * @param appContext the context of this application
-   * @param modelFileManager ModelFileManager to interact with storage layer
    * @param settings TextClassifierSettings to access DeviceConfig and other settings
    * @param executorService background executor service
    */
   public ModelDownloadManager(
       Context appContext,
-      ModelFileManager modelFileManager,
       TextClassifierSettings settings,
       ListeningExecutorService executorService) {
-    this(appContext, NewModelDownloadWorker.class, modelFileManager, settings, executorService);
+    this(
+        appContext,
+        NewModelDownloadWorker.class,
+        DownloadedModelManagerImpl.getInstance(appContext),
+        settings,
+        executorService);
   }
 
   @VisibleForTesting
   ModelDownloadManager(
       Context appContext,
       Class<? extends ListenableWorker> modelDownloadWorkerClass,
-      ModelFileManager modelFileManager,
+      DownloadedModelManager downloadedModelManager,
       TextClassifierSettings settings,
       ListeningExecutorService executorService) {
     this.appContext = Preconditions.checkNotNull(appContext);
     this.modelDownloadWorkerClass = Preconditions.checkNotNull(modelDownloadWorkerClass);
-    this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
+    this.downloadedModelManager = Preconditions.checkNotNull(downloadedModelManager);
     this.settings = Preconditions.checkNotNull(settings);
     this.executorService = Preconditions.checkNotNull(executorService);
 
@@ -104,8 +96,6 @@
         new DeviceConfig.OnPropertiesChangedListener() {
           @Override
           public void onPropertiesChanged(DeviceConfig.Properties unused) {
-            // We will only be notified for changes in our package. Trigger the check even when the
-            // change is unrelated just in case we missed a previous update.
             onTextClassifierDeviceConfigChanged();
           }
         };
@@ -118,6 +108,12 @@
         };
   }
 
+  /** Returns the downloaded model files for the given modelType, as listed by the manager. */
+  @Nullable
+  public List<File> listDownloadedModels(@ModelTypeDef String modelType) {
+    return downloadedModelManager.listModels(modelType);
+  }
+
   /** Notifies the model downlaoder that the text classifier service is created. */
   public void onTextClassifierServiceCreated() {
     DeviceConfig.addOnPropertiesChangedListener(
@@ -128,12 +124,8 @@
     if (!settings.isModelDownloadManagerEnabled()) {
       return;
     }
-    executorService.execute(
-        () -> {
-          TcLog.d(TAG, "Checking downloader flags because of the start of TextClassifierService.");
-          modelFileManager.deleteUnusedModelFiles();
-          checkConfigAndScheduleDownloads();
-        });
+    TcLog.v(TAG, "Try to schedule model download work because TextClassifierService started.");
+    scheduleDownloadWork();
   }
 
   // TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
@@ -143,12 +135,8 @@
     if (!settings.isModelDownloadManagerEnabled()) {
       return;
     }
-    executorService.execute(
-        () -> {
-          TcLog.d(TAG, "Checking downloader flags because of locale changes.");
-          modelFileManager.deleteUnusedModelFiles();
-          checkConfigAndScheduleDownloads();
-        });
+    TcLog.v(TAG, "Try to schedule model download work because of system locale changes.");
+    scheduleDownloadWork();
   }
 
   // TODO(licha): Make this private. Let the constructor accept a receiver to enable testing.
@@ -158,11 +146,8 @@
     if (!settings.isModelDownloadManagerEnabled()) {
       return;
     }
-    executorService.execute(
-        () -> {
-          TcLog.d(TAG, "Checking downloader flags because of device config changes.");
-          checkConfigAndScheduleDownloads();
-        });
+    TcLog.v(TAG, "Try to schedule model download work because of device config changes.");
+    scheduleDownloadWork();
   }
 
   /** Clean up internal states on destroying. */
@@ -173,96 +158,58 @@
   }
 
   /**
-   * Check DeviceConfig and schedule new model download requests synchronously. This method is
-   * synchronized and contains blocking operations, only call it in a background executor.
+   * Dumps the internal state for debugging.
+   *
+   * @param printWriter writer to write dumped states
    */
-  private void checkConfigAndScheduleDownloads() {
-    synchronized (lock) {
-      WorkManager workManager = WorkManager.getInstance(appContext);
-      List<Locale.LanguageRange> languageRanges =
-          Locale.LanguageRange.parse(LocaleList.getAdjustedDefault().toLanguageTags());
-      for (String modelType : ModelType.values()) {
-        // Notice: Be careful of the Locale.lookupTag() matching logic: 1) it will convert the tag
-        // to lower-case only; 2) it matches tags from tail to head and does not allow missing
-        // pieces. E.g. if your system locale is zh-hans-cn, it won't match zh-cn.
-        String bestTag =
-            Locale.lookupTag(languageRanges, settings.getLanguageTagsForManifestURL(modelType));
-        String localeTag = bestTag != null ? bestTag : UNIVERSAL_LOCALE_TAG;
-        TcLog.v(TAG, String.format("Checking model type: %s, best tag: %s", modelType, localeTag));
-
-        // One manifest url can uniquely identify a model in the world
-        String manifestUrl = settings.getManifestURL(modelType, localeTag);
-        if (TextUtils.isEmpty(manifestUrl)) {
-          continue;
-        }
-
-        // Query the history for this url. Stop if we handled this url before.
-        // TODO(licha): We may skip downloads incorrectly if we switch locales back and forth.
-        WorkQuery workQuery = WorkQuery.Builder.fromTags(ImmutableList.of(manifestUrl)).build();
-        try {
-          List<WorkInfo> workInfos = workManager.getWorkInfos(workQuery).get();
-          if (!workInfos.isEmpty()) {
-            TcLog.v(
-                TAG,
-                String.format(
-                    "Target manifest has an existing state of: %s. Skip.",
-                    workInfos.get(0).getState().name()));
-            continue;
-          }
-        } catch (ExecutionException | InterruptedException e) {
-          TcLog.e(TAG, "Failed to query queued requests. Ignore and continue.", e);
-        }
-
-        NetworkType networkType =
-            Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType())
-                .or(NetworkType.UNMETERED);
-        OneTimeWorkRequest downloadRequest =
-            new OneTimeWorkRequest.Builder(modelDownloadWorkerClass)
-                .setInputData(
-                    NewModelDownloadWorker.createInputData(
-                        modelType, localeTag, manifestUrl, settings.getModelDownloadMaxAttempts()))
-                .addTag(manifestUrl)
-                .setConstraints(
-                    new Constraints.Builder()
-                        .setRequiredNetworkType(networkType)
-                        .setRequiresBatteryNotLow(true)
-                        .setRequiresStorageNotLow(true)
-                        .build())
-                .setBackoffCriteria(
-                    BackoffPolicy.EXPONENTIAL,
-                    settings.getModelDownloadBackoffDelayInMillis(),
-                    MILLISECONDS)
-                .keepResultsForAtLeast(Duration.ofDays(DAYS_TO_KEEP_THE_DOWNLOAD_RESULT))
-                .build();
-
-        // When we enqueue a new request, existing pending request in the same queue will be
-        // cancelled. With this, device will be able to abort previous unfinished downloads
-        // (e.g. 711) when a fresher model is already(e.g. v712).
-        try {
-          // Block until we enqueue the request successfully (to avoid potential race condition)
-          workManager
-              .enqueueUniqueWork(
-                  formatUniqueWorkName(modelType, localeTag),
-                  ExistingWorkPolicy.REPLACE,
-                  downloadRequest)
-              .getResult()
-              .get();
-          TextClassifierDownloadLogger.downloadSceduled(modelType, manifestUrl);
-          TcLog.d(TAG, "Download scheduled: " + manifestUrl);
-        } catch (ExecutionException | InterruptedException e) {
-          TextClassifierDownloadLogger.downloadFailedAndAbort(
-              modelType,
-              manifestUrl,
-              ModelDownloadException.FAILED_TO_SCHEDULE,
-              /* runAttemptCount= */ 0);
-          TcLog.e(TAG, "Failed to enqueue a request", e);
-        }
-      }
-    }
+  public void dump(IndentingPrintWriter printWriter) {
+    printWriter.println("ModelDownloadManager:");
+    printWriter.increaseIndent();
+    downloadedModelManager.dump(printWriter);
+    printWriter.decreaseIndent();
   }
 
-  /** Formats unique work name for WorkManager. */
-  static String formatUniqueWorkName(@ModelType.ModelTypeDef String modelType, String localeTag) {
-    return String.format("%s-%s", modelType, localeTag);
+  /**
+   * Enqueues an idempotent work item that checks device configs and downloads models if needed.
+   *
+   * <p>All requests share one unique work name with APPEND_OR_REPLACE, so at most one runs at a
+   * time and a new request is appended after any work that is already pending or running.
+   */
+  private void scheduleDownloadWork() {
+    // An unrecognized network type name in the settings falls back to UNMETERED.
+    NetworkType requiredNetwork =
+        Enums.getIfPresent(NetworkType.class, settings.getManifestDownloadRequiredNetworkType())
+            .or(NetworkType.UNMETERED);
+    Constraints runConstraints =
+        new Constraints.Builder()
+            .setRequiredNetworkType(requiredNetwork)
+            .setRequiresBatteryNotLow(true)
+            .setRequiresStorageNotLow(true)
+            .build();
+    OneTimeWorkRequest request =
+        new OneTimeWorkRequest.Builder(modelDownloadWorkerClass)
+            .setConstraints(runConstraints)
+            .setBackoffCriteria(
+                BackoffPolicy.EXPONENTIAL,
+                settings.getModelDownloadBackoffDelayInMillis(),
+                MILLISECONDS)
+            .build();
+    Operation enqueueOperation =
+        WorkManager.getInstance(appContext)
+            .enqueueUniqueWork(UNIQUE_QUEUE_NAME, ExistingWorkPolicy.APPEND_OR_REPLACE, request);
+    Futures.addCallback(
+        enqueueOperation.getResult(),
+        new FutureCallback<Operation.State.SUCCESS>() {
+          @Override
+          public void onSuccess(Operation.State.SUCCESS ignored) {
+            TcLog.v(TAG, "Download work scheduled.");
+          }
+
+          @Override
+          public void onFailure(Throwable cause) {
+            TcLog.e(TAG, "Failed to schedule download work: ", cause);
+          }
+        },
+        executorService);
+  }
 }
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloader.java b/java/src/com/android/textclassifier/downloader/ModelDownloader.java
index 7c0c32d..7e22d99 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloader.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloader.java
@@ -35,12 +35,13 @@
   /**
    * Downloads a model file and validate it based on given model info.
    *
-   * <p>The file should be in the cache folder. Returns the File if succeed. If the download or
+   * <p>The file should be in the target folder. Returns the File on success. If the download or
    * validation fails, the unfinished model file should be cleaned up. Failures should be wrapped
    * inside a {@link ModelDownloadException} and throw.
    *
+   * @param targetDir the target directory for the downloaded model
    * @param modelInfo the model information in manifest used for downloading and validation
    * @return the downloaded model file
    */
-  ListenableFuture<File> downloadModel(ModelManifest.Model modelInfo);
+  ListenableFuture<File> downloadModel(File targetDir, ModelManifest.Model modelInfo);
 }
diff --git a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
index 564cf67..366364f 100644
--- a/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
+++ b/java/src/com/android/textclassifier/downloader/ModelDownloaderImpl.java
@@ -84,9 +84,8 @@
   }
 
   @Override
-  public ListenableFuture<File> downloadModel(ModelManifest.Model model) {
-    File modelFile =
-        new File(context.getCacheDir(), model.getUrl().replaceAll("[^A-Za-z0-9]", "_") + ".model");
+  public ListenableFuture<File> downloadModel(File targetDir, ModelManifest.Model model) {
+    File modelFile = new File(targetDir, model.getUrl().replaceAll("[^A-Za-z0-9]", "_") + ".model");
     ListenableFuture<File> modelFileFuture =
         Futures.transform(
             download(URI.create(model.getUrl()), modelFile),
@@ -100,7 +99,7 @@
         new FutureCallback<File>() {
           @Override
           public void onSuccess(File pendingModelFile) {
-            TcLog.d(TAG, "Download mode successfully: " + pendingModelFile.getAbsolutePath());
+            TcLog.v(TAG, "Download model successfully: " + pendingModelFile.getAbsolutePath());
           }
 
           @Override
diff --git a/java/src/com/android/textclassifier/downloader/NewModelDownloadWorker.java b/java/src/com/android/textclassifier/downloader/NewModelDownloadWorker.java
index 0388165..436bfe4 100644
--- a/java/src/com/android/textclassifier/downloader/NewModelDownloadWorker.java
+++ b/java/src/com/android/textclassifier/downloader/NewModelDownloadWorker.java
@@ -17,167 +17,269 @@
 package com.android.textclassifier.downloader;
 
 import android.content.Context;
-import androidx.work.Data;
+import android.util.ArrayMap;
+import android.util.Pair;
 import androidx.work.ListenableWorker;
 import androidx.work.WorkerParameters;
-import com.android.textclassifier.common.ModelFileManager;
 import com.android.textclassifier.common.ModelType;
 import com.android.textclassifier.common.ModelType.ModelTypeDef;
 import com.android.textclassifier.common.TextClassifierServiceExecutors;
 import com.android.textclassifier.common.TextClassifierSettings;
 import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.util.concurrent.FluentFuture;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
-import java.io.File;
-import java.nio.file.Files;
-import java.nio.file.StandardCopyOption;
-import java.util.concurrent.ExecutorService;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.errorprone.annotations.concurrent.GuardedBy;
+import java.util.ArrayList;
+import java.util.Locale;
 
 // TODO(licha): Rename this to ModelDownloadWorker.
-// TODO(licha): Consider whether we should let the worker handle all locales/model types.
 /** The WorkManager worker to download models for TextClassifierService. */
 public final class NewModelDownloadWorker extends ListenableWorker {
   private static final String TAG = "NewModelDownloadWorker";
-  private static final int MAX_DOWNLOAD_ATTEMPTS_DEFAULT = 5;
 
-  static final String DATA_MODEL_TYPE_KEY = "NewDownloadWorker_modelType";
-  static final String DATA_LOCALE_TAG_KEY = "NewDownloadWorker_localeTag";
-  static final String DATA_MANIFEST_URL_KEY = "NewDownloadWorker_manifestUrl";
-  static final String DATA_MAX_DOWNLOAD_ATTEMPTS_KEY = "NewDownloadWorker_maxDownloadAttempts";
-
-  @ModelTypeDef private final String modelType;
-  private final String manifestUrl;
-  private final int maxDownloadAttempts;
-
-  private final ExecutorService executorService;
+  private final ListeningExecutorService executorService;
   private final ModelDownloader downloader;
-  private final File modelDownloaderDir;
-  private final Runnable postDownloadCleanUpRunnable;
+  private final DownloadedModelManager downloadedModelManager;
+  private final TextClassifierSettings settings;
+
+  private final Object lock = new Object();
+
+  @GuardedBy("lock")
+  private final ArrayMap<String, ListenableFuture<Void>> pendingDownloads;
+
+  private ImmutableMap<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrls;
 
   public NewModelDownloadWorker(Context context, WorkerParameters workerParams) {
     super(context, workerParams);
-
-    this.modelType = Preconditions.checkNotNull(getInputData().getString(DATA_MODEL_TYPE_KEY));
-    this.manifestUrl = Preconditions.checkNotNull(getInputData().getString(DATA_MANIFEST_URL_KEY));
-    this.maxDownloadAttempts =
-        getInputData().getInt(DATA_MAX_DOWNLOAD_ATTEMPTS_KEY, MAX_DOWNLOAD_ATTEMPTS_DEFAULT);
-
     this.executorService = TextClassifierServiceExecutors.getDownloaderExecutor();
     this.downloader = new ModelDownloaderImpl(context, executorService);
-    ModelFileManager modelFileManager = new ModelFileManager(context, new TextClassifierSettings());
-    this.modelDownloaderDir = modelFileManager.getModelDownloaderDir();
-    this.postDownloadCleanUpRunnable = modelFileManager::deleteUnusedModelFiles;
+    this.downloadedModelManager = DownloadedModelManagerImpl.getInstance(context);
+    this.settings = new TextClassifierSettings();
+    this.pendingDownloads = new ArrayMap<>();
+    this.modelTypeToLocaleTagAndManifestUrls = null;
   }
 
   @VisibleForTesting
   NewModelDownloadWorker(
       Context context,
       WorkerParameters workerParams,
-      ExecutorService executorService,
+      ListeningExecutorService executorService,
       ModelDownloader modelDownloader,
-      File modelDownloaderDir,
-      Runnable postDownloadCleanUpRunnable) {
+      DownloadedModelManager downloadedModelManager,
+      TextClassifierSettings settings) {
     super(context, workerParams);
-
-    this.modelType = Preconditions.checkNotNull(getInputData().getString(DATA_MODEL_TYPE_KEY));
-    this.manifestUrl = Preconditions.checkNotNull(getInputData().getString(DATA_MANIFEST_URL_KEY));
-    this.maxDownloadAttempts =
-        getInputData().getInt(DATA_MAX_DOWNLOAD_ATTEMPTS_KEY, MAX_DOWNLOAD_ATTEMPTS_DEFAULT);
-
     this.executorService = executorService;
     this.downloader = modelDownloader;
-    this.modelDownloaderDir = modelDownloaderDir;
-    this.postDownloadCleanUpRunnable = postDownloadCleanUpRunnable;
+    this.downloadedModelManager = downloadedModelManager;
+    this.settings = settings;
+    this.pendingDownloads = new ArrayMap<>();
+    this.modelTypeToLocaleTagAndManifestUrls = null;
   }
 
   @Override
   public final ListenableFuture<ListenableWorker.Result> startWork() {
-    if (getRunAttemptCount() >= maxDownloadAttempts) {
-      TcLog.d(TAG, "Max attempt reached. Abort download task.");
-      TextClassifierDownloadLogger.downloadFailedAndAbort(
-          modelType,
-          manifestUrl,
-          // TODO(licha): Add a new failure reason for this
-          ModelDownloadException.UNKNOWN_FAILURE_REASON,
-          getRunAttemptCount());
+    // Notice: startWork() is invoked on the main thread
+    if (!settings.isModelDownloadManagerEnabled()) {
+      TcLog.e(TAG, "Model Downloader is disabled. Abort the work.");
       return Futures.immediateFuture(ListenableWorker.Result.failure());
     }
-    FluentFuture<ListenableWorker.Result> resultFuture =
-        FluentFuture.from(downloader.downloadManifest(manifestUrl))
-            .transformAsync(
-                manifest -> {
-                  // TODO(licha): put this in a function to improve the readability
-                  ModelManifest.Model modelInfo = manifest.getModels(0);
-                  File targetModelFile =
-                      new File(
-                          modelDownloaderDir,
-                          formatFileNameByModelTypeAndUrl(modelType, modelInfo.getUrl()));
-                  if (targetModelFile.exists()) {
-                    TcLog.d(
-                        TAG,
-                        "Target model file already exists. Skip download and reuse it: "
-                            + targetModelFile.getAbsolutePath());
-                    TextClassifierDownloadLogger.downloadSucceeded(
-                        modelType, manifestUrl, getRunAttemptCount());
-                    return Futures.immediateFuture(ListenableWorker.Result.success());
-                  } else {
-                    return Futures.transform(
-                        downloadAndMoveModel(targetModelFile, modelInfo),
-                        unused -> {
-                          TextClassifierDownloadLogger.downloadSucceeded(
-                              modelType, manifestUrl, getRunAttemptCount());
-                          return ListenableWorker.Result.success();
-                        },
-                        executorService);
-                  }
-                },
-                executorService)
-            .catching(
-                Throwable.class,
-                e -> {
-                  TcLog.e(TAG, "Download attempt failed.", e);
-                  int errorCode =
-                      (e instanceof ModelDownloadException)
-                          ? ((ModelDownloadException) e).getErrorCode()
-                          : ModelDownloadException.UNKNOWN_FAILURE_REASON;
-                  // Retry until reach max allowed attempts (attempt starts from 0)
-                  // The backoff time between two tries will grow exponentially (i.e. 30s, 1min,
-                  // 2min, 4min). This is configurable when building the request.
-                  TextClassifierDownloadLogger.downloadFailedAndRetry(
-                      modelType, manifestUrl, errorCode, getRunAttemptCount());
-                  return ListenableWorker.Result.retry();
-                },
-                executorService);
-    resultFuture.addListener(postDownloadCleanUpRunnable, executorService);
-    return resultFuture;
+    TcLog.v(TAG, "Start download work...");
+    if (getRunAttemptCount() >= settings.getModelDownloadWorkerMaxAttempts()) {
+      TcLog.d(TAG, "Max attempt reached. Abort download work.");
+      return Futures.immediateFuture(ListenableWorker.Result.failure());
+    }
+
+    return FluentFuture.from(Futures.submitAsync(this::checkAndDownloadModels, executorService))
+        .transform(
+            allSucceeded -> {
+              Preconditions.checkNotNull(modelTypeToLocaleTagAndManifestUrls);
+              downloadedModelManager.onDownloadCompleted(modelTypeToLocaleTagAndManifestUrls);
+              TcLog.v(TAG, "Download work completed. Succeeded: " + allSucceeded);
+              return allSucceeded
+                  ? ListenableWorker.Result.success()
+                  : ListenableWorker.Result.retry();
+            },
+            executorService)
+        .catching(
+            Throwable.class,
+            t -> {
+              TcLog.e(TAG, "Unexpected Exception during downloading: ", t);
+              return ListenableWorker.Result.retry();
+            },
+            executorService);
   }
 
-  private ListenableFuture<Void> downloadAndMoveModel(
-      File targetModelFile, ModelManifest.Model modelInfo) {
-    return Futures.transform(
-        downloader.downloadModel(modelInfo),
-        pendingModelFile -> {
-          try {
-            if (!modelDownloaderDir.exists()) {
-              modelDownloaderDir.mkdirs();
-            }
-            Files.move(
-                pendingModelFile.toPath(),
-                targetModelFile.toPath(),
-                StandardCopyOption.ATOMIC_MOVE,
-                StandardCopyOption.REPLACE_EXISTING);
-            TcLog.d(
-                TAG, "Model file downloaded successfully: " + targetModelFile.getAbsolutePath());
-            return null;
-          } catch (Exception e) {
-            pendingModelFile.delete();
-            throw new ModelDownloadException(ModelDownloadException.FAILED_TO_MOVE_MODEL, e);
-          }
-        },
-        executorService);
+  /**
+   * Checks device config and dispatches download tasks for all model types.
+   *
+   * <p>Download tasks are combined and logged after completion. The returned future resolves to
+   * true only if all tasks succeeded.
+   */
+  private ListenableFuture<Boolean> checkAndDownloadModels() {
+    Locale primaryLocale = Locale.getDefault();
+    ArrayList<ListenableFuture<Boolean>> downloadResultFutures = new ArrayList<>();
+    ImmutableMap.Builder<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrlsBuilder =
+        ImmutableMap.builder();
+    for (String modelType : ModelType.values()) {
+      Pair<String, String> bestLocaleTagAndManifestUrl =
+          LocaleUtils.lookupBestLocaleTagAndManifestUrl(modelType, primaryLocale, settings);
+      if (bestLocaleTagAndManifestUrl == null) {
+        TcLog.w(
+            TAG,
+            String.format(
+                "No suitable manifest for %s, %s", modelType, primaryLocale.toLanguageTag()));
+        continue;
+      }
+      modelTypeToLocaleTagAndManifestUrlsBuilder.put(modelType, bestLocaleTagAndManifestUrl);
+      String bestLocaleTag = bestLocaleTagAndManifestUrl.first;
+      String manifestUrl = bestLocaleTagAndManifestUrl.second;
+      TcLog.v(
+          TAG,
+          String.format(
+              "model type: %s, device locale tag: %s, best locale tag: %s, manifest url: %s",
+              modelType, primaryLocale.toLanguageTag(), bestLocaleTag, manifestUrl));
+      if (!shouldDownloadManifest(modelType, bestLocaleTag, manifestUrl)) {
+        continue;
+      }
+      downloadResultFutures.add(downloadManifestAndRegister(modelType, bestLocaleTag, manifestUrl));
+    }
+    modelTypeToLocaleTagAndManifestUrls = modelTypeToLocaleTagAndManifestUrlsBuilder.build();
+
+    return Futures.whenAllComplete(downloadResultFutures)
+        .call(
+            () -> {
+              TcLog.v(TAG, "All Download Tasks Completed");
+              boolean allSucceeded = true;
+              for (ListenableFuture<Boolean> downloadResultFuture : downloadResultFutures) {
+                allSucceeded &= Futures.getDone(downloadResultFuture);
+              }
+              return allSucceeded;
+            },
+            executorService);
+  }
+
+  private boolean shouldDownloadManifest(
+      @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+    Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl);
+    if (downloadedManifest == null) {
+      return true;
+    }
+    if (downloadedManifest.getStatus() == Manifest.STATUS_FAILED) {
+      if (downloadedManifest.getFailureCounts() >= settings.getManifestDownloadMaxAttempts()) {
+        TcLog.w(
+            TAG,
+            String.format(
+                "Manifest failed too many times, stop retrying: %s %d",
+                manifestUrl, downloadedManifest.getFailureCounts()));
+        return false;
+      } else {
+        return true;
+      }
+    }
+    ManifestEnrollment manifestEnrollment =
+        downloadedModelManager.getManifestEnrollment(modelType, localeTag);
+    return manifestEnrollment == null || !manifestUrl.equals(manifestEnrollment.getManifestUrl());
+  }
+
+  /**
+   * Downloads one manifest via downloadManifest and enrolls it for the model type and locale.
+   *
+   * <p>The returned future resolves to true once the enrollment is registered. Any Throwable is
+   * caught, recorded with registerManifestDownloadFailure, and turned into a false result.
+   */
+  private ListenableFuture<Boolean> downloadManifestAndRegister(
+      @ModelTypeDef String modelType, String localeTag, String manifestUrl) {
+    return FluentFuture.from(downloadManifest(manifestUrl))
+        .transform(
+            unused -> {
+              downloadedModelManager.registerManifestEnrollment(modelType, localeTag, manifestUrl);
+              TextClassifierDownloadLogger.downloadSucceeded(
+                  modelType, manifestUrl, getRunAttemptCount());
+              TcLog.v(TAG, "Manifest downloaded and registered: " + manifestUrl);
+              return true;
+            },
+            executorService)
+        .catching(
+            Throwable.class,
+            t -> {
+              downloadedModelManager.registerManifestDownloadFailure(manifestUrl);
+              int errorCode = ModelDownloadException.UNKNOWN_FAILURE_REASON;
+              if (t instanceof ModelDownloadException) {
+                errorCode = ((ModelDownloadException) t).getErrorCode();
+              }
+              TcLog.e(TAG, "Failed to download manifest: " + manifestUrl, t);
+              TextClassifierDownloadLogger.downloadFailedAndRetry(
+                  modelType, manifestUrl, errorCode, getRunAttemptCount());
+              return false;
+            },
+            executorService);
+  }
+
+  // Downloads a manifest and its first model, then registers it in the Manifest table.
+  private ListenableFuture<Void> downloadManifest(String manifestUrl) {
+    synchronized (lock) {
+      Manifest downloadedManifest = downloadedModelManager.getManifest(manifestUrl);
+      if (downloadedManifest != null
+          && downloadedManifest.getStatus() == Manifest.STATUS_SUCCEEDED) {
+        TcLog.v(TAG, "Manifest already downloaded: " + manifestUrl);
+        return Futures.immediateVoidFuture();
+      }
+      if (pendingDownloads.containsKey(manifestUrl)) {
+        return pendingDownloads.get(manifestUrl);
+      }
+      ListenableFuture<Void> manifestDownloadFuture =
+          FluentFuture.from(downloader.downloadManifest(manifestUrl))
+              .transformAsync(
+                  manifest -> {
+                    ModelManifest.Model modelInfo = manifest.getModels(0);
+                    return Futures.transform(
+                        downloadModel(modelInfo), unused -> modelInfo, executorService);
+                  },
+                  executorService)
+              .transform(
+                  modelInfo -> {
+                    downloadedModelManager.registerManifest(manifestUrl, modelInfo.getUrl());
+                    return null;
+                  },
+                  executorService);
+      pendingDownloads.put(manifestUrl, manifestDownloadFuture);
+      return manifestDownloadFuture;
+    }
+  }
+  // Downloads a model and registers it in the Model table.
+  private ListenableFuture<Void> downloadModel(ModelManifest.Model modelInfo) {
+    String modelUrl = modelInfo.getUrl();
+    synchronized (lock) {
+      Model downloadedModel = downloadedModelManager.getModel(modelUrl);
+      if (downloadedModel != null) {
+        TcLog.d(TAG, "Model file already exists: " + downloadedModel.getModelPath());
+        return Futures.immediateVoidFuture();
+      }
+      if (pendingDownloads.containsKey(modelUrl)) {
+        return pendingDownloads.get(modelUrl);
+      }
+      ListenableFuture<Void> modelDownloadFuture =
+          FluentFuture.from(
+                  downloader.downloadModel(
+                      downloadedModelManager.getModelDownloaderDir(), modelInfo))
+              .transform(
+                  modelFile -> {
+                    downloadedModelManager.registerModel(modelUrl, modelFile.getAbsolutePath());
+                    TcLog.v(TAG, "Model File downloaded: " + modelUrl);
+                    return null;
+                  },
+                  executorService);
+      pendingDownloads.put(modelUrl, modelDownloadFuture);
+      return modelDownloadFuture;
+    }
   }
 
   /**
@@ -187,39 +289,6 @@
    */
   @Override
   public final void onStopped() {
-    TcLog.d(TAG, String.format("Stop download: %s, attempt:%d", manifestUrl, getRunAttemptCount()));
-    TextClassifierDownloadLogger.downloadFailedAndRetry(
-        modelType, manifestUrl, ModelDownloadException.WORKER_STOPPED, getRunAttemptCount());
-  }
-
-  static final Data createInputData(
-      @ModelTypeDef String modelType,
-      String localeTag,
-      String manifestUrl,
-      int maxDownloadAttempts) {
-    return new Data.Builder()
-        .putString(DATA_MODEL_TYPE_KEY, modelType)
-        .putString(DATA_LOCALE_TAG_KEY, localeTag)
-        .putString(DATA_MANIFEST_URL_KEY, manifestUrl)
-        .putInt(DATA_MAX_DOWNLOAD_ATTEMPTS_KEY, maxDownloadAttempts)
-        .build();
-  }
-
-  /**
-   * Returns the absolute path to download a model.
-   *
-   * <p>Each file's name is uniquely formatted based on its unique remote manifest URL.
-   *
-   * @param modelType the type of the model image to download
-   * @param url the unique remote url
-   */
-  static String formatFileNameByModelTypeAndUrl(
-      @ModelType.ModelTypeDef String modelType, String url) {
-    // TODO(licha): Consider preserving the folder hierarchy of the URL
-    String fileMidName = url.replaceAll("[^A-Za-z0-9]", "_");
-    if (fileMidName.startsWith("https___")) {
-      fileMidName = fileMidName.substring("https___".length());
-    }
-    return String.format("%s.%s.model", modelType, fileMidName);
+    TcLog.d(TAG, String.format("Stop download. Attempt:%d", getRunAttemptCount()));
   }
 }
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
index 746931b..ad4a197 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -42,7 +42,6 @@
 import com.android.os.AtomsProto.TextClassifierApiUsageReported;
 import com.android.os.AtomsProto.TextClassifierApiUsageReported.ApiType;
 import com.android.os.AtomsProto.TextClassifierApiUsageReported.ResultType;
-import com.android.textclassifier.common.ModelFileManager;
 import com.android.textclassifier.common.TextClassifierSettings;
 import com.android.textclassifier.common.statsd.StatsdTestUtils;
 import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
new file mode 100644
index 0000000..3c0319f
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
@@ -0,0 +1,394 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.android.textclassifier.ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT;
+import static com.google.common.truth.Truth.assertThat;
+
+import android.os.LocaleList;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.RegularFileFullMatchLister;
+import com.android.textclassifier.ModelFileManager.RegularFilePatternMatchLister;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
+import com.google.common.base.Optional;
+import com.google.common.collect.ImmutableList;
+import com.google.common.io.Files;
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+import java.util.Locale;
+import java.util.stream.Collectors;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public final class ModelFileManagerTest {
+  private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
+
+  @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+
+  @Mock private TextClassifierSettings.IDeviceConfig mockDeviceConfig;
+
+  @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+
+  private File rootTestDir;
+  private ModelFileManager modelFileManager;
+
+  @Before
+  public void setup() {
+    MockitoAnnotations.initMocks(this);
+
+    rootTestDir =
+        new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
+    rootTestDir.mkdirs();
+    modelFileManager =
+        new ModelFileManager(
+            ApplicationProvider.getApplicationContext(),
+            new TextClassifierSettings(mockDeviceConfig));
+    setDefaultLocalesRule.set(new LocaleList(DEFAULT_LOCALE));
+  }
+
+  @After
+  public void removeTestDir() {
+    recursiveDelete(rootTestDir);
+  }
+
+  @Test
+  public void annotatorModelPreloaded() {
+    verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
+  }
+
+  @Test
+  public void actionsModelPreloaded() {
+    verifyModelPreloadedAsAsset(
+        ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
+  }
+
+  @Test
+  public void langIdModelPreloaded() {
+    verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
+  }
+
+  private void verifyModelPreloadedAsAsset(
+      @ModelTypeDef String modelType, String expectedModelPath) {
+    List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
+    List<ModelFile> assetFiles =
+        modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());
+
+    assertThat(assetFiles).hasSize(1);
+    assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
+  }
+
+  @Test
+  public void findBestModel_versionCode() {
+    ModelFileManager.ModelFile olderModelFile =
+        createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+    ModelFileManager.ModelFile newerModelFile =
+        createModelFile(LANGUAGE_INDEPENDENT, /* version */ 2);
+    ModelFileManager modelFileManager = createModelFileManager(olderModelFile, newerModelFile);
+
+    ModelFile bestModelFile = modelFileManager.findBestModelFile(MODEL_TYPE, null);
+    assertThat(bestModelFile).isEqualTo(newerModelFile);
+  }
+
+  @Test
+  public void findBestModel_languageDependentModelIsPreferred() {
+    ModelFileManager.ModelFile languageIndependentModelFile =
+        createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+    ModelFileManager.ModelFile languageDependentModelFile =
+        createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version */ 1);
+    ModelFileManager modelFileManager =
+        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+    ModelFile bestModelFile =
+        modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
+    assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+  }
+
+  @Test
+  public void findBestModel_noMatchedLanguageModel() {
+    ModelFileManager.ModelFile languageIndependentModelFile =
+        createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+    ModelFileManager.ModelFile languageDependentModelFile =
+        createModelFile("zh-hk", /* version */ 1);
+    ModelFileManager modelFileManager =
+        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+    ModelFileManager.ModelFile bestModelFile =
+        modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
+    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+  }
+
+  @Test
+  public void findBestModel_languageIsMoreImportantThanVersion() {
+    ModelFileManager.ModelFile matchButOlderModel =
+        createModelFile(DEFAULT_LOCALE.toLanguageTag(), /* version */ 1);
+    ModelFileManager.ModelFile mismatchButNewerModel = createModelFile("zh-hk", /* version */ 2);
+    ModelFileManager modelFileManager =
+        createModelFileManager(matchButOlderModel, mismatchButNewerModel);
+
+    ModelFileManager.ModelFile bestModelFile =
+        modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
+    assertThat(bestModelFile).isEqualTo(matchButOlderModel);
+  }
+
+  @Test
+  public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_onlyCheckLanguage() {
+    setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh"));
+    ModelFileManager.ModelFile languageIndependentModelFile =
+        createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+    ModelFileManager.ModelFile languageDependentModelFile = createModelFile("zh", /* version */ 1);
+    ModelFileManager modelFileManager =
+        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+    ModelFileManager.ModelFile bestModelFile =
+        modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
+    assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+  }
+
+  @Test
+  public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_match() {
+    setDefaultLocalesRule.set(LocaleList.forLanguageTags("zh-hk"));
+    ModelFileManager.ModelFile languageIndependentModelFile =
+        createModelFile(LANGUAGE_INDEPENDENT, /* version */ 1);
+    ModelFileManager.ModelFile languageDependentModelFile = createModelFile("zh", /* version */ 1);
+    ModelFileManager modelFileManager =
+        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+
+    ModelFileManager.ModelFile bestModelFile =
+        modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh"));
+    assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
+  }
+
+  @Test
+  public void findBestModel_filterOutLocalePreferenceNotInDefaultLocaleList_doNotMatch() {
+    setDefaultLocalesRule.set(LocaleList.forLanguageTags("en"));
+    ModelFileManager.ModelFile languageIndependentModelFile =
+        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
+    ModelFileManager.ModelFile languageDependentModelFile = createModelFile("zh", /* version= */ 1);
+    ModelFileManager modelFileManager =
+        createModelFileManager(languageIndependentModelFile, languageDependentModelFile);
+    // Preference "zh" is not among the default locales ("en"), so the "zh" model is not expected.
+    ModelFileManager.ModelFile bestModelFile =
+        modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh"));
+    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+  }
+
+  @Test
+  public void findBestModel_onlyPrimaryLocaleConsidered_noLocalePreferencesProvided() {
+    setDefaultLocalesRule.set(
+        new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
+    ModelFileManager.ModelFile languageIndependentModelFile =
+        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
+    ModelFileManager.ModelFile nonPrimaryLocaleModelFile =
+        createModelFile("zh-hk", /* version= */ 1);
+    ModelFileManager modelFileManager =
+        createModelFileManager(languageIndependentModelFile, nonPrimaryLocaleModelFile);
+    // No preferences given: only the primary default locale ("en") is expected to count.
+    ModelFileManager.ModelFile bestModelFile =
+        modelFileManager.findBestModelFile(MODEL_TYPE, /* localePreferences= */ null);
+    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+  }
+
+  @Test
+  public void findBestModel_onlyPrimaryLocaleConsidered_localePreferencesProvided() {
+    setDefaultLocalesRule.set(
+        new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
+
+    ModelFileManager.ModelFile languageIndependentModelFile =
+        createModelFile(LANGUAGE_INDEPENDENT, /* version= */ 1);
+    ModelFileManager.ModelFile nonPrimaryLocalePreferenceModelFile =
+        createModelFile("zh-hk", /* version= */ 1);
+    ModelFileManager modelFileManager =
+        createModelFileManager(languageIndependentModelFile, nonPrimaryLocalePreferenceModelFile);
+    // "zh-hk" comes second in both lists; the test expects its model not to be chosen.
+    ModelFileManager.ModelFile bestModelFile =
+        modelFileManager.findBestModelFile(
+            MODEL_TYPE,
+            new LocaleList(Locale.forLanguageTag("en"), Locale.forLanguageTag("zh-hk")));
+    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
+  }
+
+  @Test
+  public void modelFileEquals() {
+    ModelFileManager.ModelFile modelA =
+        new ModelFileManager.ModelFile(
+            MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+
+    ModelFileManager.ModelFile modelB =
+        new ModelFileManager.ModelFile(
+            MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+    // Both instances were built from identical constructor arguments.
+    assertThat(modelA).isEqualTo(modelB);
+  }
+
+  @Test
+  public void modelFile_different() {
+    ModelFileManager.ModelFile modelA =
+        new ModelFileManager.ModelFile(
+            MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+    ModelFileManager.ModelFile modelB =
+        new ModelFileManager.ModelFile(
+            MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
+    // The two instances differ only in their path argument.
+    assertThat(modelA).isNotEqualTo(modelB);
+  }
+
+  @Test
+  public void modelFile_isPreferredTo_languageDependentIsBetter() {
+    ModelFileManager.ModelFile modelA =
+        new ModelFileManager.ModelFile(
+            MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
+
+    ModelFileManager.ModelFile modelB =
+        new ModelFileManager.ModelFile(
+            MODEL_TYPE, "/path/b", /* version= */ 2, LANGUAGE_INDEPENDENT, /* isAsset= */ false);
+    // modelA ("ja") has the lower version, yet is expected to beat the language-independent one.
+    assertThat(modelA.isPreferredTo(modelB)).isTrue();
+  }
+
+  @Test
+  public void modelFile_isPreferredTo_version() {
+    ModelFileManager.ModelFile modelA =
+        new ModelFileManager.ModelFile(
+            MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+
+    ModelFileManager.ModelFile modelB =
+        new ModelFileManager.ModelFile(
+            MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
+    // Same locale tag on both sides; the higher version is expected to be preferred.
+    assertThat(modelA.isPreferredTo(modelB)).isTrue();
+  }
+
+  @Test
+  public void modelFile_toModelInfo() {
+    ModelFileManager.ModelFile modelFile =
+        new ModelFileManager.ModelFile(
+            MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+
+    ModelInfo modelInfo = modelFile.toModelInfo();
+    // The expected name combines the locale tag and the version passed above.
+    assertThat(modelInfo.toModelName()).isEqualTo("ja_v2");
+  }
+
+  @Test
+  public void modelFile_toModelInfos() {
+    ModelFile englishModelFile =
+        new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "en", /* isAsset= */ false);
+    ModelFile japaneseModelFile =
+        new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
+
+    ImmutableList<Optional<ModelInfo>> modelInfos =
+        ModelFileManager.ModelFile.toModelInfos(
+            Optional.of(englishModelFile), Optional.of(japaneseModelFile));
+    // Absent entries would map to ""; both are present here, and input order must be kept.
+    assertThat(
+            modelInfos.stream()
+                .map(modelFile -> modelFile.transform(ModelInfo::toModelName).or(""))
+                .collect(Collectors.toList()))
+        .containsExactly("en_v1", "ja_v2")
+        .inOrder();
+  }
+
+  @Test
+  public void regularFileFullMatchLister() throws IOException {
+    File modelFile = new File(rootTestDir, "test.model");
+    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
+    File wrongFile = new File(rootTestDir, "wrong.model");
+    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
+    // The lister is given modelFile only, so wrong.model in the same directory must not appear.
+    RegularFileFullMatchLister regularFileFullMatchLister =
+        new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
+    ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);
+
+    assertThat(listedModels).hasSize(1);
+    assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
+    assertThat(listedModels.get(0).isAsset).isFalse();
+  }
+
+  @Test
+  public void regularFilePatternMatchLister() throws IOException {
+    File modelFile1 = new File(rootTestDir, "annotator.en.model");
+    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
+    File modelFile2 = new File(rootTestDir, "annotator.fr.model");
+    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
+    File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
+    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
+
+    RegularFilePatternMatchLister regularFilePatternMatchLister =
+        new RegularFilePatternMatchLister(
+            MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
+    ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
+    // actions.en.model does not match the pattern; the listing order is not asserted.
+    assertThat(listedModels).hasSize(2);
+    assertThat(listedModels.get(0).isAsset).isFalse();
+    assertThat(listedModels.get(1).isAsset).isFalse();
+    assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath))
+        .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath());
+  }
+
+  @Test
+  public void regularFilePatternMatchLister_disabled() throws IOException {
+    File modelFile1 = new File(rootTestDir, "annotator.en.model");
+    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
+
+    RegularFilePatternMatchLister regularFilePatternMatchLister =
+        new RegularFilePatternMatchLister(
+            MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
+    ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
+    // The last constructor argument yields false, so no model is expected despite the match.
+    assertThat(listedModels).isEmpty();
+  }
+
+  private ModelFileManager createModelFileManager(ModelFile... modelFiles) {
+    return new ModelFileManager(
+        ApplicationProvider.getApplicationContext(),
+        ImmutableList.of(modelType -> ImmutableList.copyOf(modelFiles))); // modelType ignored
+  }
+
+  private ModelFileManager.ModelFile createModelFile(String supportedLocaleTags, int version) {
+    return new ModelFileManager.ModelFile(
+        MODEL_TYPE,
+        new File(rootTestDir, String.format("%s-%d", supportedLocaleTags, version))
+            .getAbsolutePath(), // only the path is computed; no file is created here
+        version,
+        supportedLocaleTags,
+        /* isAsset= */ false);
+  }
+
+  private static void recursiveDelete(File f) {
+    File[] children = f.listFiles(); // null when f is not a directory or listing fails
+    if (children != null) {
+      for (File innerFile : children) {
+        recursiveDelete(innerFile);
+      }
+    }
+    f.delete();
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
index 5c1d95e..48f71f3 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
@@ -17,8 +17,7 @@
 package com.android.textclassifier;
 
 import android.content.Context;
-import com.android.textclassifier.common.ModelFileManager;
-import com.android.textclassifier.common.ModelFileManager.RegularFileFullMatchLister;
+import com.android.textclassifier.ModelFileManager.RegularFileFullMatchLister;
 import com.android.textclassifier.common.ModelType;
 import com.google.common.collect.ImmutableList;
 import java.io.File;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index 81aa832..5c36008 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -41,7 +41,6 @@
 import androidx.test.core.app.ApplicationProvider;
 import androidx.test.ext.junit.runners.AndroidJUnit4;
 import androidx.test.filters.SmallTest;
-import com.android.textclassifier.common.ModelFileManager;
 import com.android.textclassifier.common.TextClassifierSettings;
 import com.android.textclassifier.testing.FakeContextBuilder;
 import com.google.common.collect.ImmutableList;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
deleted file mode 100644
index 40838ac..0000000
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/ModelFileManagerTest.java
+++ /dev/null
@@ -1,507 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.common;
-
-import static com.android.textclassifier.common.ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT;
-import static com.google.common.truth.Truth.assertThat;
-
-import android.os.LocaleList;
-import androidx.test.core.app.ApplicationProvider;
-import androidx.test.ext.junit.runners.AndroidJUnit4;
-import androidx.test.filters.SmallTest;
-import com.android.textclassifier.TestDataUtils;
-import com.android.textclassifier.common.ModelFileManager.ModelFile;
-import com.android.textclassifier.common.ModelFileManager.RegularFileFullMatchLister;
-import com.android.textclassifier.common.ModelFileManager.RegularFilePatternMatchLister;
-import com.android.textclassifier.common.ModelType.ModelTypeDef;
-import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
-import com.android.textclassifier.testing.SetDefaultLocalesRule;
-import com.google.common.base.Optional;
-import com.google.common.collect.ImmutableList;
-import com.google.common.io.Files;
-import java.io.File;
-import java.io.IOException;
-import java.util.List;
-import java.util.Locale;
-import java.util.stream.Collectors;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-@SmallTest
-@RunWith(AndroidJUnit4.class)
-public final class ModelFileManagerTest {
-  private static final Locale DEFAULT_LOCALE = Locale.forLanguageTag("en-US");
-
-  @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
-
-  @Mock private TextClassifierSettings.IDeviceConfig mockDeviceConfig;
-
-  @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
-
-  private File rootTestDir;
-  private ModelFileManager modelFileManager;
-
-  @Before
-  public void setup() {
-    MockitoAnnotations.initMocks(this);
-
-    rootTestDir =
-        new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
-    rootTestDir.mkdirs();
-    modelFileManager =
-        new ModelFileManager(
-            ApplicationProvider.getApplicationContext(),
-            new TextClassifierSettings(mockDeviceConfig));
-  }
-
-  @After
-  public void removeTestDir() {
-    recursiveDelete(rootTestDir);
-  }
-
-  @Test
-  public void annotatorModelPreloaded() {
-    verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
-  }
-
-  @Test
-  public void actionsModelPreloaded() {
-    verifyModelPreloadedAsAsset(
-        ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
-  }
-
-  @Test
-  public void langIdModelPreloaded() {
-    verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
-  }
-
-  private void verifyModelPreloadedAsAsset(
-      @ModelTypeDef String modelType, String expectedModelPath) {
-    List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
-    List<ModelFile> assetFiles =
-        modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());
-
-    assertThat(assetFiles).hasSize(1);
-    assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
-  }
-
-  @Test
-  public void findBestModel_versionCode() {
-    ModelFileManager.ModelFile olderModelFile =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            new File(rootTestDir, "a").getAbsolutePath(),
-            /* version= */ 1,
-            LANGUAGE_INDEPENDENT,
-            /* isAsset= */ false);
-    ModelFileManager.ModelFile newerModelFile =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            new File(rootTestDir, "b").getAbsolutePath(),
-            /* version= */ 2,
-            LANGUAGE_INDEPENDENT,
-            /* isAsset= */ false);
-    ModelFileManager modelFileManager =
-        new ModelFileManager(
-            ApplicationProvider.getApplicationContext(),
-            ImmutableList.of(modelType -> ImmutableList.of(olderModelFile, newerModelFile)));
-
-    ModelFile bestModelFile = modelFileManager.findBestModelFile(MODEL_TYPE, null);
-    assertThat(bestModelFile).isEqualTo(newerModelFile);
-  }
-
-  @Test
-  public void findBestModel_languageDependentModelIsPreferred() {
-    ModelFileManager.ModelFile languageIndependentModelFile =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            new File(rootTestDir, "a").getAbsolutePath(),
-            /* version= */ 1,
-            LANGUAGE_INDEPENDENT,
-            /* isAsset= */ false);
-    ModelFileManager.ModelFile languageDependentModelFile =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            new File(rootTestDir, "b").getAbsolutePath(),
-            /* version= */ 2,
-            DEFAULT_LOCALE.toLanguageTag(),
-            /* isAsset= */ false);
-    ModelFileManager modelFileManager =
-        new ModelFileManager(
-            ApplicationProvider.getApplicationContext(),
-            ImmutableList.of(
-                modelType ->
-                    ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
-
-    ModelFile bestModelFile =
-        modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
-    assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
-  }
-
-  @Test
-  public void findBestModel_noMatchedLanguageModel() {
-    ModelFileManager.ModelFile languageIndependentModelFile =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            new File(rootTestDir, "a").getAbsolutePath(),
-            /* version= */ 1,
-            LANGUAGE_INDEPENDENT,
-            /* isAsset= */ false);
-    ModelFileManager.ModelFile languageDependentModelFile =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            new File(rootTestDir, "b").getAbsolutePath(),
-            /* version= */ 2,
-            DEFAULT_LOCALE.toLanguageTag(),
-            /* isAsset= */ false);
-    ModelFileManager modelFileManager =
-        new ModelFileManager(
-            ApplicationProvider.getApplicationContext(),
-            ImmutableList.of(
-                modelType ->
-                    ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
-
-    ModelFileManager.ModelFile bestModelFile =
-        modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
-    assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
-  }
-
-  @Test
-  public void findBestModel_languageIsMoreImportantThanVersion() {
-    ModelFileManager.ModelFile matchButOlderModel =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            new File(rootTestDir, "a").getAbsolutePath(),
-            /* version= */ 1,
-            "fr",
-            /* isAsset= */ false);
-    ModelFileManager.ModelFile mismatchButNewerModel =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            new File(rootTestDir, "b").getAbsolutePath(),
-            /* version= */ 1,
-            "ja",
-            /* isAsset= */ false);
-    ModelFileManager modelFileManager =
-        new ModelFileManager(
-            ApplicationProvider.getApplicationContext(),
-            ImmutableList.of(
-                modelType -> ImmutableList.of(matchButOlderModel, mismatchButNewerModel)));
-
-    ModelFileManager.ModelFile bestModelFile =
-        modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("fr"));
-    assertThat(bestModelFile).isEqualTo(matchButOlderModel);
-  }
-
-  @Test
-  public void findBestModel_preferMatchedLocaleModel() {
-    ModelFileManager.ModelFile matchLocaleModel =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            new File(rootTestDir, "a").getAbsolutePath(),
-            /* version= */ 1,
-            "ja",
-            /* isAsset= */ false);
-    ModelFileManager.ModelFile languageIndependentModel =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            new File(rootTestDir, "b").getAbsolutePath(),
-            /* version= */ 1,
-            LANGUAGE_INDEPENDENT,
-            /* isAsset= */ false);
-    ModelFileManager modelFileManager =
-        new ModelFileManager(
-            ApplicationProvider.getApplicationContext(),
-            ImmutableList.of(
-                modelType -> ImmutableList.of(matchLocaleModel, languageIndependentModel)));
-
-    ModelFileManager.ModelFile bestModelFile =
-        modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("ja"));
-
-    assertThat(bestModelFile).isEqualTo(matchLocaleModel);
-  }
-
-  @Test
-  public void deleteUnusedModelFiles_olderModelDeleted() throws Exception {
-    File model1 = new File(rootTestDir, "model1.fb");
-    model1.createNewFile();
-    File model2 = new File(rootTestDir, "model2.fb");
-    model2.createNewFile();
-    ModelFileManager.ModelFile modelFile1 =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
-    ModelFileManager.ModelFile modelFile2 =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "ja", /* isAsset= */ false);
-    setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
-    ModelFileManager modelFileManager =
-        new ModelFileManager(
-            ApplicationProvider.getApplicationContext(),
-            ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
-
-    modelFileManager.deleteUnusedModelFiles();
-
-    assertThat(model1.exists()).isFalse();
-    assertThat(model2.exists()).isTrue();
-  }
-
-  @Test
-  public void deleteUnusedModelFiles_languageIndependentOlderModelDeleted() throws Exception {
-    File model1 = new File(rootTestDir, "model1.fb");
-    model1.createNewFile();
-    File model2 = new File(rootTestDir, "model2.fb");
-    model2.createNewFile();
-    ModelFileManager.ModelFile modelFile1 =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            model1.getAbsolutePath(),
-            /* version= */ 1,
-            LANGUAGE_INDEPENDENT,
-            /* isAsset= */ false);
-    ModelFileManager.ModelFile modelFile2 =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE,
-            model2.getAbsolutePath(),
-            /* version= */ 2,
-            LANGUAGE_INDEPENDENT,
-            /* isAsset= */ false);
-    setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
-    ModelFileManager modelFileManager =
-        new ModelFileManager(
-            ApplicationProvider.getApplicationContext(),
-            ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
-
-    modelFileManager.deleteUnusedModelFiles();
-
-    assertThat(model1.exists()).isFalse();
-    assertThat(model2.exists()).isTrue();
-  }
-
-  @Test
-  public void deleteUnusedModelFiles_modelOnlySupportingLocalesNotInListDeleted() throws Exception {
-    File model1 = new File(rootTestDir, "model1.fb");
-    model1.createNewFile();
-    File model2 = new File(rootTestDir, "model2.fb");
-    model2.createNewFile();
-    ModelFileManager.ModelFile modelFile1 =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
-    ModelFileManager.ModelFile modelFile2 =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 1, "en", /* isAsset= */ false);
-    setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
-    ModelFileManager modelFileManager =
-        new ModelFileManager(
-            ApplicationProvider.getApplicationContext(),
-            ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
-
-    modelFileManager.deleteUnusedModelFiles();
-
-    assertThat(model1.exists()).isTrue();
-    assertThat(model2.exists()).isFalse();
-  }
-
-  @Test
-  public void deleteUnusedModelFiles_multiLocalesInLocaleList() throws Exception {
-    File model1 = new File(rootTestDir, "model1.fb");
-    model1.createNewFile();
-    File model2 = new File(rootTestDir, "model2.fb");
-    model2.createNewFile();
-    ModelFileManager.ModelFile modelFile1 =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
-    ModelFileManager.ModelFile modelFile2 =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "en", /* isAsset= */ false);
-    setDefaultLocalesRule.set(
-        new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("en")));
-    ModelFileManager modelFileManager =
-        new ModelFileManager(
-            ApplicationProvider.getApplicationContext(),
-            ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
-    modelFileManager.deleteUnusedModelFiles();
-
-    assertThat(model1.exists()).isTrue();
-    assertThat(model2.exists()).isTrue();
-  }
-
-  @Test
-  public void deleteUnusedModelFiles_readOnlyModelsUntouched() throws Exception {
-    File readOnlyDir = new File(rootTestDir, "read_only/");
-    readOnlyDir.mkdirs();
-    File model1 = new File(readOnlyDir, "model1.fb");
-    model1.createNewFile();
-    readOnlyDir.setWritable(false);
-    ModelFileManager.ModelFile modelFile =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
-    ModelFileManager modelFileManager =
-        new ModelFileManager(
-            ApplicationProvider.getApplicationContext(),
-            ImmutableList.of(modelType -> ImmutableList.of(modelFile)));
-    setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
-
-    modelFileManager.deleteUnusedModelFiles();
-
-    assertThat(model1.exists()).isTrue();
-  }
-
-  @Test
-  public void modelFileEquals() {
-    ModelFileManager.ModelFile modelA =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
-
-    ModelFileManager.ModelFile modelB =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
-
-    assertThat(modelA).isEqualTo(modelB);
-  }
-
-  @Test
-  public void modelFile_different() {
-    ModelFileManager.ModelFile modelA =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
-    ModelFileManager.ModelFile modelB =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
-
-    assertThat(modelA).isNotEqualTo(modelB);
-  }
-
-  @Test
-  public void modelFile_isPreferredTo_languageDependentIsBetter() {
-    ModelFileManager.ModelFile modelA =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
-
-    ModelFileManager.ModelFile modelB =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, "/path/b", /* version= */ 2, LANGUAGE_INDEPENDENT, /* isAsset= */ false);
-
-    assertThat(modelA.isPreferredTo(modelB)).isTrue();
-  }
-
-  @Test
-  public void modelFile_isPreferredTo_version() {
-    ModelFileManager.ModelFile modelA =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
-
-    ModelFileManager.ModelFile modelB =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
-
-    assertThat(modelA.isPreferredTo(modelB)).isTrue();
-  }
-
-  @Test
-  public void modelFile_toModelInfo() {
-    ModelFileManager.ModelFile modelFile =
-        new ModelFileManager.ModelFile(
-            MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
-
-    ModelInfo modelInfo = modelFile.toModelInfo();
-
-    assertThat(modelInfo.toModelName()).isEqualTo("ja_v2");
-  }
-
-  @Test
-  public void modelFile_toModelInfos() {
-    ModelFile englishModelFile =
-        new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "en", /* isAsset= */ false);
-    ModelFile japaneseModelFile =
-        new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
-
-    ImmutableList<Optional<ModelInfo>> modelInfos =
-        ModelFileManager.ModelFile.toModelInfos(
-            Optional.of(englishModelFile), Optional.of(japaneseModelFile));
-
-    assertThat(
-            modelInfos.stream()
-                .map(modelFile -> modelFile.transform(ModelInfo::toModelName).or(""))
-                .collect(Collectors.toList()))
-        .containsExactly("en_v1", "ja_v2")
-        .inOrder();
-  }
-
-  @Test
-  public void regularFileFullMatchLister() throws IOException {
-    File modelFile = new File(rootTestDir, "test.model");
-    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
-    File wrongFile = new File(rootTestDir, "wrong.model");
-    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
-
-    RegularFileFullMatchLister regularFileFullMatchLister =
-        new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
-    ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);
-
-    assertThat(listedModels).hasSize(1);
-    assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
-    assertThat(listedModels.get(0).isAsset).isFalse();
-  }
-
-  @Test
-  public void regularFilePatternMatchLister() throws IOException {
-    File modelFile1 = new File(rootTestDir, "annotator.en.model");
-    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
-    File modelFile2 = new File(rootTestDir, "annotator.fr.model");
-    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
-    File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
-    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
-
-    RegularFilePatternMatchLister regularFilePatternMatchLister =
-        new RegularFilePatternMatchLister(
-            MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
-    ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
-
-    assertThat(listedModels).hasSize(2);
-    assertThat(listedModels.get(0).isAsset).isFalse();
-    assertThat(listedModels.get(1).isAsset).isFalse();
-    assertThat(ImmutableList.of(listedModels.get(0).absolutePath, listedModels.get(1).absolutePath))
-        .containsExactly(modelFile1.getAbsolutePath(), modelFile2.getAbsolutePath());
-  }
-
-  @Test
-  public void regularFilePatternMatchLister_disabled() throws IOException {
-    File modelFile1 = new File(rootTestDir, "annotator.en.model");
-    Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
-
-    RegularFilePatternMatchLister regularFilePatternMatchLister =
-        new RegularFilePatternMatchLister(
-            MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
-    ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
-
-    assertThat(listedModels).isEmpty();
-  }
-
-  private static void recursiveDelete(File f) {
-    if (f.isDirectory()) {
-      for (File innerFile : f.listFiles()) {
-        recursiveDelete(innerFile);
-      }
-    }
-    f.delete();
-  }
-}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelDatabaseTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelDatabaseTest.java
new file mode 100644
index 0000000..835f50b
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelDatabaseTest.java
@@ -0,0 +1,398 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.testng.Assert.expectThrows;
+
+import android.content.Context;
+import androidx.room.Room;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestModelCrossRef;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ModelView;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import java.io.IOException;
+import java.util.List;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+@RunWith(AndroidJUnit4.class)
+public class DownloadedModelDatabaseTest {
+  private static final String MODEL_URL = "https://model.url";
+  private static final String MODEL_URL_2 = "https://model2.url";
+  private static final String MODEL_PATH = "/data/test.model";
+  private static final String MODEL_PATH_2 = "/data/test.model2";
+  private static final String MANIFEST_URL = "https://manifest.url";
+  private static final String MANIFEST_URL_2 = "https://manifest2.url";
+  private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+  private static final String MODEL_TYPE_2 = ModelType.ACTIONS_SUGGESTIONS;
+  private static final String LOCALE_TAG = "zh";
+
+  private DownloadedModelDatabase db;
+
+  @Before
+  public void createDb() {
+    Context context = ApplicationProvider.getApplicationContext();
+    db = Room.inMemoryDatabaseBuilder(context, DownloadedModelDatabase.class).build();
+  }
+
+  @After
+  public void closeDb() throws IOException {
+    db.close();
+  }
+
+  @Test
+  public void insertModelAndRead() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    List<Model> models = db.dao().queryAllModels();
+    assertThat(models).containsExactly(model);
+  }
+
+  @Test
+  public void insertModelAndDelete() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    db.dao().deleteModels(ImmutableList.of(model));
+    List<Model> models = db.dao().queryAllModels();
+    assertThat(models).isEmpty();
+  }
+
+  @Test
+  public void insertManifestAndRead() throws Exception {
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    List<Manifest> manifests = db.dao().queryAllManifests();
+    assertThat(manifests).containsExactly(manifest);
+  }
+
+  @Test
+  public void insertManifestAndDelete() throws Exception {
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    db.dao().deleteManifests(ImmutableList.of(manifest));
+    List<Manifest> manifests = db.dao().queryAllManifests();
+    assertThat(manifests).isEmpty();
+  }
+
+  @Test
+  public void insertManifestModelCrossRefAndRead() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    ManifestModelCrossRef manifestModelCrossRef =
+        ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+    db.dao().insert(manifestModelCrossRef);
+    List<ManifestModelCrossRef> manifestModelCrossRefs = db.dao().queryAllManifestModelCrossRefs();
+    assertThat(manifestModelCrossRefs).containsExactly(manifestModelCrossRef);
+  }
+
+  @Test
+  public void insertManifestModelCrossRefAndDelete() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    ManifestModelCrossRef manifestModelCrossRef =
+        ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+    db.dao().insert(manifestModelCrossRef);
+    db.dao().deleteManifestModelCrossRefs(ImmutableList.of(manifestModelCrossRef));
+    List<ManifestModelCrossRef> manifestModelCrossRefs = db.dao().queryAllManifestModelCrossRefs();
+    assertThat(manifestModelCrossRefs).isEmpty();
+  }
+
+  @Test
+  public void insertManifestModelCrossRefAndDeleteManifest() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    ManifestModelCrossRef manifestModelCrossRef =
+        ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+    db.dao().insert(manifestModelCrossRef);
+    db.dao().deleteManifests(ImmutableList.of(manifest)); // ON CASCADE
+    List<ManifestModelCrossRef> manifestModelCrossRefs = db.dao().queryAllManifestModelCrossRefs();
+    assertThat(manifestModelCrossRefs).isEmpty();
+  }
+
+  @Test
+  public void insertManifestModelCrossRefAndDeleteModel() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    ManifestModelCrossRef manifestModelCrossRef =
+        ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+    db.dao().insert(manifestModelCrossRef);
+    db.dao().deleteModels(ImmutableList.of(model)); // ON CASCADE
+    List<ManifestModelCrossRef> manifestModelCrossRefs = db.dao().queryAllManifestModelCrossRefs();
+    assertThat(manifestModelCrossRefs).isEmpty();
+  }
+
+  @Test
+  public void insertManifestModelCrossRefWithoutManifest() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    ManifestModelCrossRef manifestModelCrossRef =
+        ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+    expectThrows(Throwable.class, () -> db.dao().insert(manifestModelCrossRef));
+  }
+
+  @Test
+  public void insertManifestModelCrossRefWithoutModel() throws Exception {
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    ManifestModelCrossRef manifestModelCrossRef =
+        ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+    expectThrows(Throwable.class, () -> db.dao().insert(manifestModelCrossRef));
+  }
+
+  @Test
+  public void insertManifestEnrollmentAndRead() throws Exception {
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    ManifestEnrollment manifestEnrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    db.dao().insert(manifestEnrollment);
+    List<ManifestEnrollment> manifestEnrollments = db.dao().queryAllManifestEnrollments();
+    assertThat(manifestEnrollments).containsExactly(manifestEnrollment);
+  }
+
+  @Test
+  public void insertManifestEnrollmentAndDelete() throws Exception {
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    ManifestEnrollment manifestEnrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    db.dao().insert(manifestEnrollment);
+    db.dao().deleteManifestEnrollments(ImmutableList.of(manifestEnrollment));
+    List<ManifestEnrollment> manifestEnrollments = db.dao().queryAllManifestEnrollments();
+    assertThat(manifestEnrollments).isEmpty();
+  }
+
+  @Test
+  public void insertManifestEnrollmentAndDeleteManifest() throws Exception {
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    ManifestEnrollment manifestEnrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    db.dao().insert(manifestEnrollment);
+    db.dao().deleteManifests(ImmutableList.of(manifest)); // ON CASCADE
+    List<ManifestEnrollment> manifestEnrollments = db.dao().queryAllManifestEnrollments();
+    assertThat(manifestEnrollments).isEmpty();
+  }
+
+  @Test
+  public void insertManifestEnrollmentWithoutManifest() throws Exception {
+    ManifestEnrollment manifestEnrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    expectThrows(Throwable.class, () -> db.dao().insert(manifestEnrollment));
+  }
+
+  @Test
+  public void insertModelViewAndRead() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    ManifestModelCrossRef manifestModelCrossRef =
+        ManifestModelCrossRef.create(MANIFEST_URL, MODEL_URL);
+    db.dao().insert(manifestModelCrossRef);
+    ManifestEnrollment manifestEnrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    db.dao().insert(manifestEnrollment);
+
+    List<ModelView> modelViews = db.dao().queryAllModelViews();
+    ModelView modelView = Iterables.getOnlyElement(modelViews);
+    assertThat(modelView.getManifestEnrollment()).isEqualTo(manifestEnrollment);
+    assertThat(modelView.getModel()).isEqualTo(model);
+  }
+
+  @Test
+  public void queryModelWithModelUrl() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    Model model2 = Model.create(MODEL_URL_2, MODEL_PATH_2);
+    db.dao().insert(model2);
+
+    assertThat(db.dao().queryModelWithModelUrl(MODEL_URL)).containsExactly(model);
+    assertThat(db.dao().queryModelWithModelUrl(MODEL_URL_2)).containsExactly(model2);
+  }
+
+  @Test
+  public void queryManifestWithManifestUrl() throws Exception {
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    Manifest manifest2 =
+        Manifest.create(MANIFEST_URL_2, Manifest.STATUS_FAILED, /* failureCounts= */ 1);
+    db.dao().insert(manifest2);
+
+    assertThat(db.dao().queryManifestWithManifestUrl(MANIFEST_URL)).containsExactly(manifest);
+    assertThat(db.dao().queryManifestWithManifestUrl(MANIFEST_URL_2)).containsExactly(manifest2);
+  }
+
+  @Test
+  public void queryManifestEnrollmentWithModelTypeAndLocaleTag() throws Exception {
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    Manifest manifest2 =
+        Manifest.create(MANIFEST_URL_2, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest2);
+    ManifestEnrollment manifestEnrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    db.dao().insert(manifestEnrollment);
+    ManifestEnrollment manifestEnrollment2 =
+        ManifestEnrollment.create(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL_2);
+    db.dao().insert(manifestEnrollment2);
+
+    assertThat(db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(MODEL_TYPE, LOCALE_TAG))
+        .containsExactly(manifestEnrollment);
+    assertThat(db.dao().queryManifestEnrollmentWithModelTypeAndLocaleTag(MODEL_TYPE_2, LOCALE_TAG))
+        .containsExactly(manifestEnrollment2);
+  }
+
+  @Test
+  public void insertManifestAndModelCrossRef() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL);
+
+    assertThat(db.dao().queryAllModels()).containsExactly(model);
+    assertThat(db.dao().queryAllManifests()).containsExactly(manifest);
+  }
+
+  @Test
+  public void increaseManifestFailureCounts() throws Exception {
+    db.dao().increaseManifestFailureCounts(MANIFEST_URL); // no row yet, so this creates it
+    Manifest actual = Iterables.getOnlyElement(db.dao().queryManifestWithManifestUrl(MANIFEST_URL));
+    assertThat(actual.getStatus()).isEqualTo(Manifest.STATUS_FAILED);
+    assertThat(actual.getFailureCounts()).isEqualTo(1);
+    db.dao().increaseManifestFailureCounts(MANIFEST_URL); // existing row: count goes 1 -> 2
+    actual = Iterables.getOnlyElement(db.dao().queryManifestWithManifestUrl(MANIFEST_URL));
+    assertThat(actual.getStatus()).isEqualTo(Manifest.STATUS_FAILED);
+    assertThat(actual.getFailureCounts()).isEqualTo(2);
+  }
+
+  @Test
+  public void deleteUnusedManifestsAndModels_unusedManifestAndUnusedModel() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    Model model2 = Model.create(MODEL_URL_2, MODEL_PATH_2);
+    db.dao().insert(model2);
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL);
+    Manifest manifest2 =
+        Manifest.create(MANIFEST_URL_2, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest2);
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL_2, MODEL_URL_2);
+    ManifestEnrollment manifestEnrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    db.dao().insert(manifestEnrollment);
+
+    db.dao().deleteUnusedManifestsAndModels();
+    assertThat(db.dao().queryAllManifests()).containsExactly(manifest);
+    assertThat(db.dao().queryAllModels()).containsExactly(model);
+  }
+
+  @Test
+  public void deleteUnusedManifestsAndModels_unusedManifestAndSharedModel() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL);
+    Manifest manifest2 =
+        Manifest.create(MANIFEST_URL_2, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest2);
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL_2, MODEL_URL);
+    ManifestEnrollment manifestEnrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    db.dao().insert(manifestEnrollment);
+
+    db.dao().deleteUnusedManifestsAndModels();
+    assertThat(db.dao().queryAllManifests()).containsExactly(manifest);
+    assertThat(db.dao().queryAllModels()).containsExactly(model);
+  }
+
+  @Test
+  public void deleteUnusedManifestsAndModels_failedManifest() throws Exception {
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_FAILED, /* failureCounts= */ 1);
+    db.dao().insert(manifest);
+
+    db.dao().deleteUnusedManifestsAndModels();
+    assertThat(db.dao().queryAllManifests()).containsExactly(manifest);
+  }
+
+  @Test
+  public void deleteUnusedManifestsAndModels_unusedModels() throws Exception {
+    Model model = Model.create(MODEL_URL, MODEL_PATH);
+    db.dao().insert(model);
+    Model model2 = Model.create(MODEL_URL_2, MODEL_PATH_2);
+    db.dao().insert(model2);
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0);
+    db.dao().insert(manifest);
+    db.dao().insertManifestAndModelCrossRef(MANIFEST_URL, MODEL_URL);
+    ManifestEnrollment manifestEnrollment =
+        ManifestEnrollment.create(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    db.dao().insert(manifestEnrollment);
+
+    db.dao().deleteUnusedManifestsAndModels();
+    assertThat(db.dao().queryAllModels()).containsExactly(model);
+  }
+
+  @Test
+  public void deleteUnusedManifestFailureRecords() throws Exception {
+    Manifest manifest =
+        Manifest.create(MANIFEST_URL, Manifest.STATUS_FAILED, /* failureCounts= */ 1);
+    db.dao().insert(manifest);
+    Manifest manifest2 =
+        Manifest.create(MANIFEST_URL_2, Manifest.STATUS_FAILED, /* failureCounts= */ 1);
+    db.dao().insert(manifest2);
+
+    db.dao().deleteUnusedManifestFailureRecords(ImmutableList.of(MANIFEST_URL));
+    assertThat(db.dao().queryAllManifests()).containsExactly(manifest);
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelManagerImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelManagerImplTest.java
new file mode 100644
index 0000000..2715a05
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloadedModelManagerImplTest.java
@@ -0,0 +1,280 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Context;
+import android.util.Pair;
+import androidx.room.Room;
+import androidx.test.core.app.ApplicationProvider;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.ModelType.ModelTypeDef;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Manifest;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestEnrollment;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.ManifestModelCrossRef;
+import com.android.textclassifier.downloader.DownloadedModelDatabase.Model;
+import com.android.textclassifier.testing.TestingDeviceConfig;
+import com.google.common.collect.ImmutableMap;
+import java.io.File;
+import java.util.Map;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public final class DownloadedModelManagerImplTest {
+
+  private File modelDownloaderDir;
+  private DownloadedModelDatabase db;
+  private DownloadedModelManagerImpl downloadedModelManagerImpl;
+  private TestingDeviceConfig deviceConfig;
+  private TextClassifierSettings settings;
+
+  @Before
+  public void setUp() {
+    Context context = ApplicationProvider.getApplicationContext();
+    modelDownloaderDir = new File(context.getFilesDir(), "test_dir");
+    modelDownloaderDir.mkdirs();
+    deviceConfig = new TestingDeviceConfig();
+    settings = new TextClassifierSettings(deviceConfig);
+    db = Room.inMemoryDatabaseBuilder(context, DownloadedModelDatabase.class).build();
+    downloadedModelManagerImpl =
+        DownloadedModelManagerImpl.getInstanceForTesting(db, modelDownloaderDir, settings);
+  }
+
+  @After
+  public void cleanUp() {
+    DownloaderTestUtils.deleteRecursively(modelDownloaderDir);
+    db.close();
+  }
+
+  @Test
+  public void getModelDownloaderDir() throws Exception {
+    modelDownloaderDir.delete();
+    assertThat(downloadedModelManagerImpl.getModelDownloaderDir().exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.getModelDownloaderDir()).isEqualTo(modelDownloaderDir);
+  }
+
+  @Test
+  public void listModels_cacheNotInitialized() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
+    registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
+
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(new File("modelPathEn"), new File("modelPathZh"));
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.LANG_ID)).isEmpty();
+  }
+
+  @Test
+  public void listModels_doNotListBlockedModels() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
+    registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
+    deviceConfig.setConfig(
+        TextClassifierSettings.MODEL_URL_BLOCKLIST,
+        String.format(
+            "%s%s%s",
+            "modelUrlEn", TextClassifierSettings.MODEL_URL_BLOCKLIST_SEPARATOR, "modelUrlXX"));
+
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(new File("modelPathZh"));
+  }
+
+  @Test
+  public void listModels_cacheNotUpdatedUnlessOnDownloadCompleted() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrlEn", "modelUrlEn", "modelPathEn");
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(new File("modelPathEn"));
+
+    registerManifestToDB(ModelType.ANNOTATOR, "zh", "manifestUrlZh", "modelUrlZh", "modelPathZh");
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(new File("modelPathEn"));
+
+    Map<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrls =
+        ImmutableMap.of(ModelType.ANNOTATOR, Pair.create("zh", "manifestUrlZh"));
+    downloadedModelManagerImpl.onDownloadCompleted(modelTypeToLocaleTagAndManifestUrls);
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(new File("modelPathZh"));
+  }
+
+  @Test
+  public void getModel() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
+    assertThat(downloadedModelManagerImpl.getModel("modelUrl").getModelPath())
+        .isEqualTo("modelPath");
+    assertThat(downloadedModelManagerImpl.getModel("modelUrl2")).isNull();
+  }
+
+  @Test
+  public void getManifest() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
+    assertThat(downloadedModelManagerImpl.getManifest("manifestUrl")).isNotNull();
+    assertThat(downloadedModelManagerImpl.getManifest("manifestUrl2")).isNull();
+  }
+
+  @Test
+  public void getManifestEnrollment() throws Exception {
+    registerManifestToDB(ModelType.ANNOTATOR, "en", "manifestUrl", "modelUrl", "modelPath");
+    assertThat(
+            downloadedModelManagerImpl
+                .getManifestEnrollment(ModelType.ANNOTATOR, "en")
+                .getManifestUrl())
+        .isEqualTo("manifestUrl");
+    assertThat(downloadedModelManagerImpl.getManifestEnrollment(ModelType.ANNOTATOR, "zh"))
+        .isNull();
+  }
+
+  @Test
+  public void registerModel() throws Exception {
+    downloadedModelManagerImpl.registerModel("modelUrl", "modelPath");
+
+    assertThat(downloadedModelManagerImpl.getModel("modelUrl").getModelPath())
+        .isEqualTo("modelPath");
+  }
+
+  @Test
+  public void registerManifest() throws Exception {
+    downloadedModelManagerImpl.registerModel("modelUrl", "modelPath");
+    downloadedModelManagerImpl.registerManifest("manifestUrl", "modelUrl");
+
+    assertThat(downloadedModelManagerImpl.getManifest("manifestUrl")).isNotNull();
+  }
+
+  @Test
+  public void registerManifestDownloadFailure() throws Exception {
+    downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl");
+
+    Manifest manifest = downloadedModelManagerImpl.getManifest("manifestUrl");
+    assertThat(manifest.getStatus()).isEqualTo(Manifest.STATUS_FAILED);
+    assertThat(manifest.getFailureCounts()).isEqualTo(1);
+  }
+
+  @Test
+  public void registerManifestEnrollment() throws Exception {
+    downloadedModelManagerImpl.registerModel("modelUrl", "modelPath");
+    downloadedModelManagerImpl.registerManifest("manifestUrl", "modelUrl");
+    downloadedModelManagerImpl.registerManifestEnrollment(ModelType.ANNOTATOR, "en", "manifestUrl");
+
+    ManifestEnrollment manifestEnrollment =
+        downloadedModelManagerImpl.getManifestEnrollment(ModelType.ANNOTATOR, "en");
+    assertThat(manifestEnrollment.getModelType()).isEqualTo(ModelType.ANNOTATOR);
+    assertThat(manifestEnrollment.getLocaleTag()).isEqualTo("en");
+    assertThat(manifestEnrollment.getManifestUrl()).isEqualTo("manifestUrl");
+  }
+
+  @Test
+  public void onDownloadCompleted_newModelDownloaded() throws Exception {
+    Map<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrls =
+        ImmutableMap.of(ModelType.ANNOTATOR, Pair.create("en", "manifestUrl1"));
+    File modelFile1 = new File(modelDownloaderDir, "modelFile1");
+    modelFile1.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath());
+    downloadedModelManagerImpl.onDownloadCompleted(modelTypeToLocaleTagAndManifestUrls);
+
+    assertThat(modelFile1.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(modelFile1);
+
+    modelTypeToLocaleTagAndManifestUrls =
+        ImmutableMap.of(ModelType.ANNOTATOR, Pair.create("en", "manifestUrl2"));
+    File modelFile2 = new File(modelDownloaderDir, "modelFile2");
+    modelFile2.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl2", "modelUrl2", modelFile2.getAbsolutePath());
+    downloadedModelManagerImpl.onDownloadCompleted(modelTypeToLocaleTagAndManifestUrls);
+
+    assertThat(modelFile1.exists()).isFalse();
+    assertThat(modelFile2.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(modelFile2);
+  }
+
+  @Test
+  public void onDownloadCompleted_newModelDownloadFailed() throws Exception {
+    Map<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrls =
+        ImmutableMap.of(ModelType.ANNOTATOR, Pair.create("en", "manifestUrl1"));
+    File modelFile1 = new File(modelDownloaderDir, "modelFile1");
+    modelFile1.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath());
+    downloadedModelManagerImpl.onDownloadCompleted(modelTypeToLocaleTagAndManifestUrls);
+
+    assertThat(modelFile1.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(modelFile1);
+
+    modelTypeToLocaleTagAndManifestUrls =
+        ImmutableMap.of(ModelType.ANNOTATOR, Pair.create("en", "manifestUrl2"));
+    downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl2");
+    downloadedModelManagerImpl.onDownloadCompleted(modelTypeToLocaleTagAndManifestUrls);
+
+    assertThat(modelFile1.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(modelFile1);
+  }
+
+  @Test
+  public void onDownloadCompleted_flatUnset() throws Exception {
+    Map<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrls =
+        ImmutableMap.of(ModelType.ANNOTATOR, Pair.create("en", "manifestUrl1"));
+    File modelFile1 = new File(modelDownloaderDir, "modelFile1");
+    modelFile1.createNewFile();
+    registerManifestToDB(
+        ModelType.ANNOTATOR, "en", "manifestUrl1", "modelUrl1", modelFile1.getAbsolutePath());
+    downloadedModelManagerImpl.onDownloadCompleted(modelTypeToLocaleTagAndManifestUrls);
+
+    assertThat(modelFile1.exists()).isTrue();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR))
+        .containsExactly(modelFile1);
+
+    modelTypeToLocaleTagAndManifestUrls = ImmutableMap.of(); // empty: no manifest wanted any more
+    downloadedModelManagerImpl.onDownloadCompleted(modelTypeToLocaleTagAndManifestUrls);
+
+    assertThat(modelFile1.exists()).isFalse();
+    assertThat(downloadedModelManagerImpl.listModels(ModelType.ANNOTATOR)).isEmpty();
+  }
+
+  @Test
+  public void onDownloadCompleted_cleanUpFailureRecords() throws Exception {
+    Map<String, Pair<String, String>> modelTypeToLocaleTagAndManifestUrls =
+        ImmutableMap.of(ModelType.ANNOTATOR, Pair.create("en", "manifestUrl1"));
+    downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl1");
+    downloadedModelManagerImpl.registerManifestDownloadFailure("manifestUrl2");
+    downloadedModelManagerImpl.onDownloadCompleted(modelTypeToLocaleTagAndManifestUrls);
+
+    assertThat(downloadedModelManagerImpl.getManifest("manifestUrl1").getStatus())
+        .isEqualTo(Manifest.STATUS_FAILED);
+    assertThat(downloadedModelManagerImpl.getManifest("manifestUrl2")).isNull();
+  }
+
+  private void registerManifestToDB(
+      @ModelTypeDef String modelType,
+      String localeTag,
+      String manifestUrl,
+      String modelUrl,
+      String modelPath) {
+    db.dao().insert(Model.create(modelUrl, modelPath));
+    db.dao()
+        .insert(Manifest.create(manifestUrl, Manifest.STATUS_SUCCEEDED, /* failureCounts= */ 0));
+    db.dao().insert(ManifestModelCrossRef.create(manifestUrl, modelUrl));
+    db.dao().insert(ManifestEnrollment.create(modelType, localeTag, manifestUrl));
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloaderTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloaderTestUtils.java
index 1337130..37394e6 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloaderTestUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/DownloaderTestUtils.java
@@ -16,80 +16,21 @@
 
 package com.android.textclassifier.downloader;
 
-import android.content.Context;
-import androidx.work.ListenableWorker;
 import androidx.work.WorkInfo;
 import androidx.work.WorkManager;
 import androidx.work.WorkQuery;
-import androidx.work.WorkerParameters;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Iterables;
-import com.google.common.util.concurrent.Futures;
-import com.google.common.util.concurrent.ListenableFuture;
 import java.io.File;
 import java.util.List;
 
 /** Utils for downloader logic testing. */
 final class DownloaderTestUtils {
 
-  /** One unique queue holds at most one request at one time. Returns null if no WorkInfo found. */
-  public static WorkInfo queryTheOnlyWorkInfo(WorkManager workManager, String queueName)
+  public static List<WorkInfo> queryWorkInfos(WorkManager workManager, String queueName)
       throws Exception {
     WorkQuery workQuery =
         WorkQuery.Builder.fromUniqueWorkNames(ImmutableList.of(queueName)).build();
-    List<WorkInfo> workInfos = workManager.getWorkInfos(workQuery).get();
-    if (workInfos.isEmpty()) {
-      return null;
-    } else {
-      return Iterables.getOnlyElement(workInfos);
-    }
-  }
-
-  /**
-   * Completes immediately with the pre-set result. If it's not retry, the result will also include
-   * the input Data as its output Data.
-   */
-  public static final class TestWorker extends ListenableWorker {
-    private static Result expectedResult;
-
-    public TestWorker(Context context, WorkerParameters workerParams) {
-      super(context, workerParams);
-    }
-
-    @Override
-    public ListenableFuture<ListenableWorker.Result> startWork() {
-      if (expectedResult == null) {
-        return Futures.immediateFailedFuture(new Exception("no expected result"));
-      }
-      ListenableWorker.Result result;
-      switch (expectedResult) {
-        case SUCCESS:
-          result = ListenableWorker.Result.success(getInputData());
-          break;
-        case FAILURE:
-          result = ListenableWorker.Result.failure(getInputData());
-          break;
-        case RETRY:
-          result = ListenableWorker.Result.retry();
-          break;
-        default:
-          throw new IllegalStateException("illegal result");
-      }
-      // Reset expected result
-      expectedResult = null;
-      return Futures.immediateFuture(result);
-    }
-
-    /** Sets the expected worker result in a static variable. Will be cleaned up after reading. */
-    public static void setExpectedResult(Result expectedResult) {
-      TestWorker.expectedResult = expectedResult;
-    }
-
-    public enum Result {
-      SUCCESS,
-      FAILURE,
-      RETRY;
-    }
+    return workManager.getWorkInfos(workQuery).get();
   }
 
   // MoreFiles#deleteRecursively is not available for Android guava.
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/LocaleUtilsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/LocaleUtilsTest.java
new file mode 100644
index 0000000..a553c51
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/LocaleUtilsTest.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.downloader;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.util.Pair;
+import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.TextClassifierSettings;
+import com.android.textclassifier.testing.TestingDeviceConfig;
+import com.google.common.collect.ImmutableList;
+import java.util.Locale;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link LocaleUtils}. */
+@RunWith(JUnit4.class)
+public final class LocaleUtilsTest {
+  private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+
+  // Locales requested by the tests below.
+  private static final Locale ENGLISH = Locale.forLanguageTag("en");
+  private static final Locale US_ENGLISH = Locale.forLanguageTag("en-US");
+
+  private TestingDeviceConfig deviceConfig;
+  private TextClassifierSettings settings;
+
+  @Before
+  public void setUp() {
+    this.deviceConfig = new TestingDeviceConfig();
+    this.settings = new TextClassifierSettings(this.deviceConfig);
+  }
+
+  @Test
+  public void lookupBestLocaleTag_simpleMatch() {
+    // The requested tag is itself one of the candidates.
+    assertThat(LocaleUtils.lookupBestLocaleTag(ENGLISH, ImmutableList.of("en", "zh")))
+        .isEqualTo("en");
+  }
+
+  @Test
+  public void lookupBestLocaleTag_noMatch() {
+    // Different language.
+    assertThat(LocaleUtils.lookupBestLocaleTag(ENGLISH, ImmutableList.of("zh"))).isNull();
+    // Candidate is more specific than the requested locale.
+    assertThat(LocaleUtils.lookupBestLocaleTag(ENGLISH, ImmutableList.of("en-uk"))).isNull();
+    // Same language, different region.
+    assertThat(LocaleUtils.lookupBestLocaleTag(US_ENGLISH, ImmutableList.of("en-uk"))).isNull();
+  }
+
+  @Test
+  public void lookupBestLocaleTag_partialMatch() {
+    // Only the language-only candidate matches.
+    assertThat(LocaleUtils.lookupBestLocaleTag(US_ENGLISH, ImmutableList.of("en", "zh")))
+        .isEqualTo("en");
+    // The region-specific candidate is expected over the language-only one.
+    assertThat(LocaleUtils.lookupBestLocaleTag(US_ENGLISH, ImmutableList.of("en", "en-us")))
+        .isEqualTo("en-us");
+    // The candidate for another region is passed over in favor of "en".
+    assertThat(LocaleUtils.lookupBestLocaleTag(US_ENGLISH, ImmutableList.of("en", "en-uk")))
+        .isEqualTo("en");
+  }
+
+  @Test
+  public void lookupBestLocaleTag_universalMatch() {
+    // "zh" does not match "en", so the universal tag is the expected result.
+    assertThat(
+            LocaleUtils.lookupBestLocaleTag(
+                ENGLISH, ImmutableList.of("zh", LocaleUtils.UNIVERSAL_LOCALE_TAG)))
+        .isEqualTo(LocaleUtils.UNIVERSAL_LOCALE_TAG);
+  }
+
+  @Test
+  public void lookupBestLocaleTagAndManifestUrl_found() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, "en", "url_1");
+
+    Pair<String, String> localeTagAndUrl =
+        LocaleUtils.lookupBestLocaleTagAndManifestUrl(MODEL_TYPE, ENGLISH, settings);
+
+    assertThat(localeTagAndUrl.first).isEqualTo("en");
+    assertThat(localeTagAndUrl.second).isEqualTo("url_1");
+  }
+
+  @Test
+  public void lookupBestLocaleTagAndManifestUrl_notFound() throws Exception {
+    // This test sets no manifest URL flag before the lookup.
+    Pair<String, String> localeTagAndUrl =
+        LocaleUtils.lookupBestLocaleTagAndManifestUrl(MODEL_TYPE, ENGLISH, settings);
+
+    assertThat(localeTagAndUrl).isNull();
+  }
+
+  /** Sets the manifest URL flag for the given model type and locale tag to {@code url}. */
+  private void setUpManifestUrl(
+      @ModelType.ModelTypeDef String modelType, String localeTag, String url) {
+    deviceConfig.setConfig(
+        String.format(TextClassifierSettings.MANIFEST_URL_TEMPLATE, modelType, localeTag), url);
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
index 09593c2..351d00b 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloadManagerTest.java
@@ -17,19 +17,15 @@
 package com.android.textclassifier.downloader;
 
 import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.when;
 
 import android.content.Context;
 import android.os.LocaleList;
 import androidx.test.core.app.ApplicationProvider;
 import androidx.test.ext.junit.runners.AndroidJUnit4;
-import androidx.work.ExistingWorkPolicy;
-import androidx.work.OneTimeWorkRequest;
 import androidx.work.WorkInfo;
 import androidx.work.WorkManager;
-import androidx.work.testing.TestDriver;
 import androidx.work.testing.WorkManagerTestInitHelper;
-import com.android.os.AtomsProto.TextClassifierDownloadReported;
-import com.android.textclassifier.common.ModelFileManager;
 import com.android.textclassifier.common.ModelType;
 import com.android.textclassifier.common.TextClassifierSettings;
 import com.android.textclassifier.common.statsd.TextClassifierDownloadLoggerTestRule;
@@ -38,27 +34,23 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
 import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.util.List;
 import java.util.Locale;
+import java.util.stream.Collectors;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
 @RunWith(AndroidJUnit4.class)
 public final class ModelDownloadManagerTest {
-  private static final String MANIFEST_URL =
-      "https://www.gstatic.com/android/text_classifier/x/v123/en.fb.manifest";
-  private static final String MANIFEST_URL_2 =
-      "https://www.gstatic.com/android/text_classifier/y/v456/zh.fb.manifest";
-  // Parameterized test is not yet supported for instrumentation test
+  private static final String MODEL_PATH = "/data/test.model";
   @ModelType.ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
-  private static final TextClassifierDownloadReported.ModelType MODEL_TYPE_ATOM =
-      TextClassifierDownloadReported.ModelType.ANNOTATOR;
   private static final String LOCALE_TAG = "en";
-  private static final String LOCALE_TAG_2 = "zh";
-  private static final String LOCALE_UNIVERSAL_TAG = ModelDownloadManager.UNIVERSAL_LOCALE_TAG;
   private static final LocaleList DEFAULT_LOCALE_LIST = new LocaleList(new Locale(LOCALE_TAG));
 
   @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
@@ -70,8 +62,8 @@
   // TODO(licha): Maybe we can just use the real TextClassifierSettings
   private TestingDeviceConfig deviceConfig;
   private WorkManager workManager;
-  private TestDriver workManagerTestDriver;
   private ModelDownloadManager downloadManager;
+  @Mock DownloadedModelManager downloadedModelManager;
 
   @Before
   public void setUp() {
@@ -81,276 +73,84 @@
 
     this.deviceConfig = new TestingDeviceConfig();
     this.workManager = WorkManager.getInstance(context);
-    this.workManagerTestDriver = WorkManagerTestInitHelper.getTestDriver(context);
-    ModelFileManager modelFileManager = new ModelFileManager(context, ImmutableList.of());
     this.downloadManager =
         new ModelDownloadManager(
             context,
-            DownloaderTestUtils.TestWorker.class,
-            modelFileManager,
+            NewModelDownloadWorker.class,
+            downloadedModelManager,
             new TextClassifierSettings(deviceConfig),
             MoreExecutors.newDirectExecutorService());
+
     setDefaultLocalesRule.set(DEFAULT_LOCALE_LIST);
     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
   }
 
   @After
   public void tearDown() {
+    workManager.cancelUniqueWork(ModelDownloadManager.UNIQUE_QUEUE_NAME);
     DownloaderTestUtils.deleteRecursively(
         ApplicationProvider.getApplicationContext().getFilesDir());
   }
 
   @Test
   public void onTextClassifierServiceCreated_requestEnqueued() throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
     downloadManager.onTextClassifierServiceCreated();
 
-    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
+    WorkInfo workInfo =
+        Iterables.getOnlyElement(
+            DownloaderTestUtils.queryWorkInfos(
+                workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
     assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
   }
 
   @Test
   public void onLocaleChanged_requestEnqueued() throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
     downloadManager.onLocaleChanged();
 
-    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
+    WorkInfo workInfo =
+        Iterables.getOnlyElement(
+            DownloaderTestUtils.queryWorkInfos(
+                workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
     assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
   }
 
   @Test
   public void onTextClassifierDeviceConfigChanged_requestEnqueued() throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
     downloadManager.onTextClassifierDeviceConfigChanged();
 
-    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
+    WorkInfo workInfo =
+        Iterables.getOnlyElement(
+            DownloaderTestUtils.queryWorkInfos(
+                workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME));
     assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
   }
 
   @Test
   public void onTextClassifierDeviceConfigChanged_downloaderDisabled() throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
     deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, false);
     downloadManager.onTextClassifierDeviceConfigChanged();
 
-    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    assertThat(workInfo).isNull();
-  }
-
-  @Test
-  public void onTextClassifierDeviceConfigChanged_flagNotSet() throws Exception {
-    downloadManager.onTextClassifierDeviceConfigChanged();
-
-    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    assertThat(workInfo).isNull();
-  }
-
-  @Test
-  public void onTextClassifierDeviceConfigChanged_skipManifestProcessedBefore() throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
-    // Simulates a previous model download task
-    OneTimeWorkRequest modelDownloadRequest =
-        new OneTimeWorkRequest.Builder(DownloaderTestUtils.TestWorker.class)
-            .addTag(MANIFEST_URL)
-            .build();
-    DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.SUCCESS);
-    workManager
-        .enqueueUniqueWork(
-            ModelDownloadManager.formatUniqueWorkName(MODEL_TYPE, LOCALE_TAG),
-            ExistingWorkPolicy.REPLACE,
-            modelDownloadRequest)
-        .getResult()
-        .get();
-
-    // Assert the model download work succeeded
-    WorkInfo oldWorkInfo =
-        DownloaderTestUtils.queryTheOnlyWorkInfo(
-            workManager, ModelDownloadManager.formatUniqueWorkName(MODEL_TYPE, LOCALE_TAG));
-    assertThat(oldWorkInfo.getState()).isEqualTo(WorkInfo.State.SUCCEEDED);
-
-    // Trigger the config check
-    downloadManager.onTextClassifierDeviceConfigChanged();
-    WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    assertThat(newWorkInfo.getId()).isEqualTo(oldWorkInfo.getId());
-  }
-
-  @Test
-  public void onTextClassifierDeviceConfigChanged_scheduleWorkForAllModelTypes() throws Exception {
-    for (@ModelType.ModelTypeDef String modelType : ModelType.values()) {
-      setUpManifestUrl(modelType, LOCALE_TAG, modelType + MANIFEST_URL);
-    }
-    downloadManager.onTextClassifierDeviceConfigChanged();
-
-    for (@ModelType.ModelTypeDef String modelType : ModelType.values()) {
-      WorkInfo workInfo = queryTheOnlyWorkInfo(modelType, LOCALE_TAG);
-      assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-    }
-  }
-
-  @Test
-  public void onTextClassifierDeviceConfigChanged_checkIsIdempotent() throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
-    downloadManager.onTextClassifierDeviceConfigChanged();
-    WorkInfo oldWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    downloadManager.onTextClassifierDeviceConfigChanged();
-    WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-
-    // Will not schedule multiple times, still the same WorkInfo
-    assertThat(oldWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-    assertThat(newWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-    assertThat(oldWorkInfo.getId()).isEqualTo(newWorkInfo.getId());
-    assertThat(oldWorkInfo.getTags()).containsExactlyElementsIn(newWorkInfo.getTags());
-  }
-
-  @Test
-  public void onTextClassifierDeviceConfigChanged_newWorkReplaceOldWork() throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
-    downloadManager.onTextClassifierDeviceConfigChanged();
-    WorkInfo oldWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL_2);
-    downloadManager.onTextClassifierDeviceConfigChanged();
-    WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-
-    // oldWorkInfo will be replaced with the newWorkInfo
-    assertThat(oldWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-    assertThat(newWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-    assertThat(oldWorkInfo.getId()).isNotEqualTo(newWorkInfo.getId());
-    assertThat(oldWorkInfo.getTags()).contains(MANIFEST_URL);
-    assertThat(newWorkInfo.getTags()).contains(MANIFEST_URL_2);
-  }
-
-  @Test
-  public void onTextClassifierDeviceConfigChanged_workerSucceeded() throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
-    downloadManager.onTextClassifierDeviceConfigChanged();
-
-    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-
-    DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.SUCCESS);
-    workManagerTestDriver.setAllConstraintsMet(workInfo.getId());
-
-    WorkInfo succeededWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    assertThat(succeededWorkInfo.getId()).isEqualTo(workInfo.getId());
-    assertThat(succeededWorkInfo.getState()).isEqualTo(WorkInfo.State.SUCCEEDED);
     assertThat(
-            succeededWorkInfo.getOutputData().getString(NewModelDownloadWorker.DATA_MODEL_TYPE_KEY))
-        .isEqualTo(MODEL_TYPE);
-    assertThat(
-            succeededWorkInfo.getOutputData().getString(NewModelDownloadWorker.DATA_LOCALE_TAG_KEY))
-        .isEqualTo(LOCALE_TAG);
-    assertThat(
-            succeededWorkInfo
-                .getOutputData()
-                .getString(NewModelDownloadWorker.DATA_MANIFEST_URL_KEY))
-        .isEqualTo(MANIFEST_URL);
+            DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME))
+        .isEmpty();
   }
 
   @Test
-  public void onTextClassifierDeviceConfigChanged_workerFailed() throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+  public void onTextClassifierDeviceConfigChanged_newWorkDoNotReplaceOldWork() throws Exception {
     downloadManager.onTextClassifierDeviceConfigChanged();
+    downloadManager.onTextClassifierDeviceConfigChanged();
+    List<WorkInfo> workInfos =
+        DownloaderTestUtils.queryWorkInfos(workManager, ModelDownloadManager.UNIQUE_QUEUE_NAME);
 
-    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-
-    DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.FAILURE);
-    workManagerTestDriver.setAllConstraintsMet(workInfo.getId());
-
-    WorkInfo failedWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    assertThat(failedWorkInfo.getId()).isEqualTo(workInfo.getId());
-    assertThat(failedWorkInfo.getState()).isEqualTo(WorkInfo.State.FAILED);
-    assertThat(failedWorkInfo.getOutputData().getString(NewModelDownloadWorker.DATA_MODEL_TYPE_KEY))
-        .isEqualTo(MODEL_TYPE);
-    assertThat(failedWorkInfo.getOutputData().getString(NewModelDownloadWorker.DATA_LOCALE_TAG_KEY))
-        .isEqualTo(LOCALE_TAG);
-    assertThat(
-            failedWorkInfo.getOutputData().getString(NewModelDownloadWorker.DATA_MANIFEST_URL_KEY))
-        .isEqualTo(MANIFEST_URL);
+    assertThat(workInfos.stream().map(WorkInfo::getState).collect(Collectors.toList()))
+        .containsExactly(WorkInfo.State.ENQUEUED, WorkInfo.State.BLOCKED);
   }
 
   @Test
-  public void onTextClassifierDeviceConfigChanged_workerRetried() throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
-    downloadManager.onTextClassifierDeviceConfigChanged();
+  public void listDownloadedModels() throws Exception {
+    File modelFile = new File(MODEL_PATH);
+    when(downloadedModelManager.listModels(MODEL_TYPE)).thenReturn(ImmutableList.of(modelFile));
 
-    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-
-    DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.RETRY);
-    workManagerTestDriver.setAllConstraintsMet(workInfo.getId());
-
-    WorkInfo retryWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    assertThat(retryWorkInfo.getId()).isEqualTo(workInfo.getId());
-    assertThat(retryWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-    assertThat(retryWorkInfo.getRunAttemptCount()).isEqualTo(1);
-  }
-
-  @Test
-  public void onTextClassifierDeviceConfigChanged_chooseTheBestLocaleTag() throws Exception {
-    // System default locale: zh-hant-hk
-    setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("zh-hant-hk")));
-
-    // All configured locale tags
-    setUpManifestUrl(MODEL_TYPE, "zh-hant", MANIFEST_URL); // best match
-    setUpManifestUrl(MODEL_TYPE, "zh", MANIFEST_URL_2); // too general
-    setUpManifestUrl(MODEL_TYPE, "zh-hk", MANIFEST_URL_2); // missing script
-    setUpManifestUrl(MODEL_TYPE, "zh-hans-hk", MANIFEST_URL_2); // incorrect script
-    setUpManifestUrl(MODEL_TYPE, "es-hant-hk", MANIFEST_URL_2); // incorrect language
-
-    downloadManager.onTextClassifierDeviceConfigChanged();
-
-    // The downloader choose: zh-hant
-    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh-hant").getState())
-        .isEqualTo(WorkInfo.State.ENQUEUED);
-
-    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh")).isNull();
-    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh-hk")).isNull();
-    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh-hans-hk")).isNull();
-    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "es-hant-hk")).isNull();
-  }
-
-  @Test
-  public void onTextClassifierDeviceConfigChanged_useUniversalModelIfNoMatchedTag()
-      throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG_2, MANIFEST_URL);
-    setUpManifestUrl(MODEL_TYPE, LOCALE_UNIVERSAL_TAG, MANIFEST_URL_2);
-    downloadManager.onTextClassifierDeviceConfigChanged();
-
-    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG_2)).isNull();
-
-    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_UNIVERSAL_TAG);
-    assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-    assertThat(workInfo.getTags()).contains(MANIFEST_URL_2);
-  }
-
-  @Test
-  public void onTextClassifierDeviceConfigChanged_logAfterDownloadScheduled() throws Exception {
-    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
-    downloadManager.onTextClassifierDeviceConfigChanged();
-
-    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, LOCALE_TAG);
-    assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
-
-    // verify log
-    TextClassifierDownloadReported atom = Iterables.getOnlyElement(loggerTestRule.getLoggedAtoms());
-    assertThat(atom.getDownloadStatus())
-        .isEqualTo(TextClassifierDownloadReported.DownloadStatus.SCHEDULED);
-    assertThat(atom.getModelType()).isEqualTo(MODEL_TYPE_ATOM);
-    assertThat(atom.getUrlSuffix()).isEqualTo(MANIFEST_URL);
-  }
-
-  private void setUpManifestUrl(
-      @ModelType.ModelTypeDef String modelType, String localeTag, String url) {
-    String deviceConfigFlag =
-        String.format(TextClassifierSettings.MANIFEST_URL_TEMPLATE, modelType, localeTag);
-    deviceConfig.setConfig(deviceConfigFlag, url);
-  }
-
-  /** One unique queue holds at most one request at one time. Returns null if no WorkInfo found. */
-  private WorkInfo queryTheOnlyWorkInfo(@ModelType.ModelTypeDef String modelType, String localeTag)
-      throws Exception {
-    return DownloaderTestUtils.queryTheOnlyWorkInfo(
-        workManager, ModelDownloadManager.formatUniqueWorkName(modelType, localeTag));
+    assertThat(downloadManager.listDownloadedModels(MODEL_TYPE)).containsExactly(modelFile);
   }
 }
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderImplTest.java
index 8c6d13f..47e7fb6 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/ModelDownloaderImplTest.java
@@ -24,10 +24,10 @@
 import androidx.test.core.app.ApplicationProvider;
 import com.android.textclassifier.downloader.TestModelDownloaderService.DownloadResult;
 import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
 import java.io.File;
 import java.nio.file.Files;
 import java.util.concurrent.CancellationException;
-import java.util.concurrent.Executors;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -53,18 +53,23 @@
       ModelManifest.newBuilder().addModels(MODEL_PROTO).build();
 
   private ModelDownloaderImpl modelDownloaderImpl;
+  private File modelDownloaderDir;
 
   @Before
   public void setUp() {
     Context context = ApplicationProvider.getApplicationContext();
     this.modelDownloaderImpl =
         new ModelDownloaderImpl(
-            context, Executors.newSingleThreadExecutor(), TestModelDownloaderService.class);
+            context, MoreExecutors.newDirectExecutorService(), TestModelDownloaderService.class);
+    this.modelDownloaderDir = new File(context.getFilesDir(), "downloader");
+    this.modelDownloaderDir.mkdirs();
+
+    TestModelDownloaderService.reset();
   }
 
   @After
   public void tearDown() {
-    TestModelDownloaderService.reset();
+    DownloaderTestUtils.deleteRecursively(modelDownloaderDir);
   }
 
   @Test
@@ -148,9 +153,12 @@
     assertThat(TestModelDownloaderService.isBound()).isFalse();
 
     TestModelDownloaderService.setBindSucceed(true);
-    TestModelDownloaderService.setDownloadResult(MANIFEST_URL, DownloadResult.BLOCKING, null);
+    TestModelDownloaderService.setDownloadResult(MANIFEST_URL, DownloadResult.DO_NOTHING, null);
     ListenableFuture<ModelManifest> manifestFuture =
         modelDownloaderImpl.downloadManifest(MANIFEST_URL);
+
+    assertThat(TestModelDownloaderService.getOnBindInvokedLatch().await(1L, SECONDS)).isTrue();
+    assertThat(TestModelDownloaderService.isBound()).isTrue();
     manifestFuture.cancel(true);
 
     expectThrows(CancellationException.class, manifestFuture::get);
@@ -165,7 +173,8 @@
     assertThat(TestModelDownloaderService.isBound()).isFalse();
 
     TestModelDownloaderService.setBindSucceed(false);
-    ListenableFuture<File> modelFuture = modelDownloaderImpl.downloadModel(MODEL_PROTO);
+    ListenableFuture<File> modelFuture =
+        modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
 
     Throwable t = expectThrows(Throwable.class, modelFuture::get);
     assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
@@ -184,9 +193,11 @@
     TestModelDownloaderService.setBindSucceed(true);
     TestModelDownloaderService.setDownloadResult(
         MODEL_URL, DownloadResult.SUCCEEDED, MODEL_CONTENT_BYTES);
-    ListenableFuture<File> modelFuture = modelDownloaderImpl.downloadModel(MODEL_PROTO);
+    ListenableFuture<File> modelFuture =
+        modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
 
     File modelFile = modelFuture.get();
+    assertThat(modelFile.getParentFile()).isEqualTo(modelDownloaderDir);
     assertThat(Files.readAllBytes(modelFile.toPath())).isEqualTo(MODEL_CONTENT_BYTES);
     assertThat(TestModelDownloaderService.getOnUnbindInvokedLatch().await(1L, SECONDS)).isTrue();
     assertThat(TestModelDownloaderService.isBound()).isFalse();
@@ -200,7 +211,8 @@
 
     TestModelDownloaderService.setBindSucceed(true);
     TestModelDownloaderService.setDownloadResult(MODEL_URL, DownloadResult.FAILED, null);
-    ListenableFuture<File> modelFuture = modelDownloaderImpl.downloadModel(MODEL_PROTO);
+    ListenableFuture<File> modelFuture =
+        modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
 
     Throwable t = expectThrows(Throwable.class, modelFuture::get);
     assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
@@ -220,7 +232,8 @@
     TestModelDownloaderService.setBindSucceed(true);
     TestModelDownloaderService.setDownloadResult(
         MODEL_URL, DownloadResult.SUCCEEDED, "randomString".getBytes());
-    ListenableFuture<File> modelFuture = modelDownloaderImpl.downloadModel(MODEL_PROTO);
+    ListenableFuture<File> modelFuture =
+        modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
 
     Throwable t = expectThrows(Throwable.class, modelFuture::get);
     assertThat(t).hasCauseThat().isInstanceOf(ModelDownloadException.class);
@@ -237,8 +250,12 @@
     assertThat(TestModelDownloaderService.isBound()).isFalse();
 
     TestModelDownloaderService.setBindSucceed(true);
-    TestModelDownloaderService.setDownloadResult(MODEL_URL, DownloadResult.BLOCKING, null);
-    ListenableFuture<File> modelFuture = modelDownloaderImpl.downloadModel(MODEL_PROTO);
+    TestModelDownloaderService.setDownloadResult(MODEL_URL, DownloadResult.DO_NOTHING, null);
+    ListenableFuture<File> modelFuture =
+        modelDownloaderImpl.downloadModel(modelDownloaderDir, MODEL_PROTO);
+
+    assertThat(TestModelDownloaderService.getOnBindInvokedLatch().await(1L, SECONDS)).isTrue();
+    assertThat(TestModelDownloaderService.isBound()).isTrue();
     modelFuture.cancel(true);
 
     expectThrows(CancellationException.class, modelFuture::get);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/NewModelDownloadWorkerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/NewModelDownloadWorkerTest.java
index 34b6660..4a2741b 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/NewModelDownloadWorkerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/NewModelDownloadWorkerTest.java
@@ -17,10 +17,13 @@
 package com.android.textclassifier.downloader;
 
 import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 import android.content.Context;
+import android.os.LocaleList;
+import androidx.room.Room;
 import androidx.test.core.app.ApplicationProvider;
 import androidx.work.ListenableWorker;
 import androidx.work.WorkerFactory;
@@ -30,11 +33,17 @@
 import com.android.os.AtomsProto.TextClassifierDownloadReported.DownloadStatus;
 import com.android.os.AtomsProto.TextClassifierDownloadReported.FailureReason;
 import com.android.textclassifier.common.ModelType;
+import com.android.textclassifier.common.TextClassifierSettings;
 import com.android.textclassifier.common.statsd.TextClassifierDownloadLoggerTestRule;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
+import com.android.textclassifier.testing.TestingDeviceConfig;
 import com.google.common.collect.Iterables;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.MoreExecutors;
 import java.io.File;
+import java.util.List;
+import java.util.Locale;
+import java.util.stream.Collectors;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
@@ -47,35 +56,54 @@
 @RunWith(JUnit4.class)
 public final class NewModelDownloadWorkerTest {
   private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+  private static final String MODEL_TYPE_2 = ModelType.ACTIONS_SUGGESTIONS;
   private static final TextClassifierDownloadReported.ModelType MODEL_TYPE_ATOM =
       TextClassifierDownloadReported.ModelType.ANNOTATOR;
   private static final String LOCALE_TAG = "en";
   private static final String MANIFEST_URL =
       "https://www.gstatic.com/android/text_classifier/q/v711/en.fb.manifest";
+  private static final String MANIFEST_URL_2 =
+      "https://www.gstatic.com/android/text_classifier/q/v711/zh.fb.manifest";
   private static final String MODEL_URL =
       "https://www.gstatic.com/android/text_classifier/q/v711/en.fb";
+  private static final String MODEL_URL_2 =
+      "https://www.gstatic.com/android/text_classifier/q/v711/zh.fb";
   private static final int RUN_ATTEMPT_COUNT = 1;
-  private static final int MAX_RUN_ATTEMPT_COUNT = 5;
+  private static final int WORKER_MAX_RUN_ATTEMPT_COUNT = 5;
+  private static final int MANIFEST_MAX_ATTEMPT_COUNT = 2;
   private static final ModelManifest.Model MODEL_PROTO =
       ModelManifest.Model.newBuilder()
           .setUrl(MODEL_URL)
           .setSizeInBytes(1)
           .setFingerprint("fingerprint")
           .build();
+  private static final ModelManifest.Model MODEL_PROTO_2 =
+      ModelManifest.Model.newBuilder()
+          .setUrl(MODEL_URL_2)
+          .setSizeInBytes(1)
+          .setFingerprint("fingerprint")
+          .build();
   private static final ModelManifest MODEL_MANIFEST_PROTO =
       ModelManifest.newBuilder().addModels(MODEL_PROTO).build();
+  private static final ModelManifest MODEL_MANIFEST_PROTO_2 =
+      ModelManifest.newBuilder().addModels(MODEL_PROTO_2).build();
   private static final ModelDownloadException FAILED_TO_DOWNLOAD_EXCEPTION =
       new ModelDownloadException(
           ModelDownloadException.FAILED_TO_DOWNLOAD_OTHER, "failed to download");
   private static final FailureReason FAILED_TO_DOWNLOAD_FAILURE_REASON =
       TextClassifierDownloadReported.FailureReason.FAILED_TO_DOWNLOAD_OTHER;
-
-  private File downloadedModelDir;
-  private File pendingModelDir;
-  private File targetModelFile;
+  private static final LocaleList DEFAULT_LOCALE_LIST = new LocaleList(new Locale(LOCALE_TAG));
 
   @Mock private ModelDownloader modelDownloader;
-  @Mock private Runnable postDownloadCleanUpRunnable;
+  private File modelDownloaderDir;
+  private File modelFile;
+  private File modelFile2;
+  private DownloadedModelDatabase db;
+  private DownloadedModelManager downloadedModelManager;
+  private TestingDeviceConfig deviceConfig;
+  private TextClassifierSettings settings;
+
+  @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
 
   @Rule
   public final TextClassifierDownloadLoggerTestRule loggerTestRule =
@@ -86,133 +114,219 @@
     MockitoAnnotations.initMocks(this);
 
     Context context = ApplicationProvider.getApplicationContext();
-    this.downloadedModelDir = new File(context.getCacheDir(), "downloaded");
-    this.downloadedModelDir.mkdirs();
-    this.pendingModelDir = new File(context.getCacheDir(), "pending");
-    this.pendingModelDir.mkdirs();
-    this.targetModelFile =
-        new File(
-            downloadedModelDir,
-            NewModelDownloadWorker.formatFileNameByModelTypeAndUrl(MODEL_TYPE, MODEL_URL));
-    this.targetModelFile.delete();
+    this.deviceConfig = new TestingDeviceConfig();
+    this.settings = new TextClassifierSettings(deviceConfig);
+    this.modelDownloaderDir = new File(context.getCacheDir(), "downloaded");
+    this.modelDownloaderDir.mkdirs();
+    this.modelFile = new File(modelDownloaderDir, "test.model");
+    this.modelFile2 = new File(modelDownloaderDir, "test2.model");
+    this.db = Room.inMemoryDatabaseBuilder(context, DownloadedModelDatabase.class).build();
+    this.downloadedModelManager =
+        DownloadedModelManagerImpl.getInstanceForTesting(db, modelDownloaderDir, settings);
+
+    setDefaultLocalesRule.set(DEFAULT_LOCALE_LIST);
+    deviceConfig.setConfig(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED, true);
   }
 
   @After
   public void cleanUp() {
-    DownloaderTestUtils.deleteRecursively(downloadedModelDir);
-    DownloaderTestUtils.deleteRecursively(pendingModelDir);
+    db.close();
+    DownloaderTestUtils.deleteRecursively(modelDownloaderDir);
   }
 
   @Test
-  public void downloadManifest_succeed_downloadModel_succeed_moveModelFile_succeed()
-      throws Exception {
+  public void downloadSucceed() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
     when(modelDownloader.downloadManifest(MANIFEST_URL))
         .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
-    File pendingModelFile = new File(pendingModelDir, "pending.model.file");
-    pendingModelFile.createNewFile();
-    when(modelDownloader.downloadModel(MODEL_PROTO))
-        .thenReturn(Futures.immediateFuture(pendingModelFile));
+    modelFile.createNewFile();
+    when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+        .thenReturn(Futures.immediateFuture(modelFile));
 
     NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
     assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
-    assertThat(targetModelFile.exists()).isTrue();
-    assertThat(pendingModelFile.exists()).isFalse();
-    verify(postDownloadCleanUpRunnable).run();
+    assertThat(modelFile.exists()).isTrue();
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
     verifyLoggedAtom(DownloadStatus.SUCCEEDED, RUN_ATTEMPT_COUNT, /* failureReason= */ null);
   }
 
   @Test
-  public void downloadManifest_failed() throws Exception {
+  public void downloadSucceed_modelAlreadyExists() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    when(modelDownloader.downloadManifest(MANIFEST_URL))
+        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+    modelFile.createNewFile();
+    downloadedModelManager.registerModel(MODEL_URL, modelFile.getAbsolutePath());
+
+    NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(modelFile.exists()).isTrue();
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+    verifyLoggedAtom(DownloadStatus.SUCCEEDED, RUN_ATTEMPT_COUNT, /* failureReason= */ null);
+  }
+
+  @Test
+  public void downloadSucceed_manifestAlreadyExists() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    modelFile.createNewFile();
+    downloadedModelManager.registerModel(MODEL_URL, modelFile.getAbsolutePath());
+    downloadedModelManager.registerManifest(MANIFEST_URL, MODEL_URL);
+
+    NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(modelFile.exists()).isTrue();
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+    verifyLoggedAtom(DownloadStatus.SUCCEEDED, RUN_ATTEMPT_COUNT, /* failureReason= */ null);
+  }
+
+  @Test
+  public void downloadSucceed_downloadMultipleModels() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    setUpManifestUrl(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL_2);
+    when(modelDownloader.downloadManifest(MANIFEST_URL))
+        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+    when(modelDownloader.downloadManifest(MANIFEST_URL_2))
+        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO_2));
+    modelFile.createNewFile();
+    modelFile2.createNewFile();
+    when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+        .thenReturn(Futures.immediateFuture(modelFile));
+    when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO_2))
+        .thenReturn(Futures.immediateFuture(modelFile2));
+
+    NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(modelFile.exists()).isTrue();
+    assertThat(modelFile2.exists()).isTrue();
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE_2)).containsExactly(modelFile2);
+    List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedAtoms();
+    assertThat(atoms).hasSize(2);
+    assertThat(
+            atoms.stream()
+                .map(TextClassifierDownloadReported::getUrlSuffix)
+                .collect(Collectors.toList()))
+        .containsExactly(MANIFEST_URL, MANIFEST_URL_2);
+    assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+    assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+  }
+
+  @Test
+  public void downloadSucceed_shareSingleModelDownloadForMultipleManifest() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    setUpManifestUrl(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL_2);
+    when(modelDownloader.downloadManifest(MANIFEST_URL))
+        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+    when(modelDownloader.downloadManifest(MANIFEST_URL_2))
+        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+    modelFile.createNewFile();
+    when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+        .thenReturn(Futures.immediateFuture(modelFile));
+
+    NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(modelFile.exists()).isTrue();
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE_2)).containsExactly(modelFile);
+    verify(modelDownloader, times(1)).downloadModel(modelDownloaderDir, MODEL_PROTO);
+    List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedAtoms();
+    assertThat(atoms).hasSize(2);
+    assertThat(
+            atoms.stream()
+                .map(TextClassifierDownloadReported::getUrlSuffix)
+                .collect(Collectors.toList()))
+        .containsExactly(MANIFEST_URL, MANIFEST_URL_2);
+    assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+    assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+  }
+
+  @Test
+  public void downloadSucceed_shareManifest() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    setUpManifestUrl(MODEL_TYPE_2, LOCALE_TAG, MANIFEST_URL);
+    when(modelDownloader.downloadManifest(MANIFEST_URL))
+        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
+    modelFile.createNewFile();
+    when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
+        .thenReturn(Futures.immediateFuture(modelFile));
+
+    NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(modelFile.exists()).isTrue();
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE)).containsExactly(modelFile);
+    assertThat(downloadedModelManager.listModels(MODEL_TYPE_2)).containsExactly(modelFile);
+    verify(modelDownloader, times(1)).downloadManifest(MANIFEST_URL);
+    List<TextClassifierDownloadReported> atoms = loggerTestRule.getLoggedAtoms();
+    assertThat(atoms).hasSize(2);
+    assertThat(
+            atoms.stream()
+                .map(TextClassifierDownloadReported::getUrlSuffix)
+                .collect(Collectors.toList()))
+        .containsExactly(MANIFEST_URL, MANIFEST_URL);
+    assertThat(atoms.get(0).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+    assertThat(atoms.get(1).getDownloadStatus()).isEqualTo(DownloadStatus.SUCCEEDED);
+  }
+
+  @Test
+  public void downloadFailed_failedToDownloadManifest() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
     when(modelDownloader.downloadManifest(MANIFEST_URL))
         .thenReturn(Futures.immediateFailedFuture(FAILED_TO_DOWNLOAD_EXCEPTION));
 
     NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
     assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
-    assertThat(targetModelFile.exists()).isFalse();
-    verify(postDownloadCleanUpRunnable).run();
     verifyLoggedAtom(
         DownloadStatus.FAILED_AND_RETRY, RUN_ATTEMPT_COUNT, FAILED_TO_DOWNLOAD_FAILURE_REASON);
   }
 
   @Test
-  public void downloadManifest_succeed_downloadModel_failed() throws Exception {
+  public void downloadFailed_failedToDownloadModel() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
     when(modelDownloader.downloadManifest(MANIFEST_URL))
         .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
-    when(modelDownloader.downloadModel(MODEL_PROTO))
+    when(modelDownloader.downloadModel(modelDownloaderDir, MODEL_PROTO))
         .thenReturn(Futures.immediateFailedFuture(FAILED_TO_DOWNLOAD_EXCEPTION));
 
     NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
     assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
-    assertThat(targetModelFile.exists()).isFalse();
-    verify(postDownloadCleanUpRunnable).run();
     verifyLoggedAtom(
         DownloadStatus.FAILED_AND_RETRY, RUN_ATTEMPT_COUNT, FAILED_TO_DOWNLOAD_FAILURE_REASON);
   }
 
   @Test
-  public void downloadManifest_succeed_downloadModel_alreadyExist() throws Exception {
-    when(modelDownloader.downloadManifest(MANIFEST_URL))
-        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
-    when(modelDownloader.downloadModel(MODEL_PROTO))
-        .thenReturn(Futures.immediateFailedFuture(FAILED_TO_DOWNLOAD_EXCEPTION));
-    targetModelFile.createNewFile();
-
-    NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
-    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
-    assertThat(targetModelFile.exists()).isTrue();
-    verify(postDownloadCleanUpRunnable).run();
-    verifyLoggedAtom(DownloadStatus.SUCCEEDED, RUN_ATTEMPT_COUNT, /* failureReason= */ null);
-  }
-
-  @Test
-  public void downloadManifest_succeed_downloadModel_succeed_moveModelFile_failed()
-      throws Exception {
-    when(modelDownloader.downloadManifest(MANIFEST_URL))
-        .thenReturn(Futures.immediateFuture(MODEL_MANIFEST_PROTO));
-    File pendingModelFile = new File(pendingModelDir, "pending.model.file");
-    pendingModelFile.createNewFile();
-    when(modelDownloader.downloadModel(MODEL_PROTO))
-        .thenReturn(Futures.immediateFuture(pendingModelFile));
-
-    try {
-      downloadedModelDir.setWritable(false);
-
-      NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
-      assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
-      assertThat(targetModelFile.exists()).isFalse();
-      assertThat(pendingModelFile.exists()).isFalse();
-      verify(postDownloadCleanUpRunnable).run();
-      verifyLoggedAtom(
-          DownloadStatus.FAILED_AND_RETRY, RUN_ATTEMPT_COUNT, FailureReason.FAILED_TO_MOVE_MODEL);
-    } finally {
-      downloadedModelDir.setWritable(true);
-    }
-  }
-
-  @Test
-  public void reachMaxRunAttempts() throws Exception {
-    NewModelDownloadWorker worker = createWorker(MAX_RUN_ATTEMPT_COUNT);
+  public void downloadFailed_reachWorkerMaxRunAttempts() throws Exception {
+    NewModelDownloadWorker worker = createWorker(WORKER_MAX_RUN_ATTEMPT_COUNT);
     assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.failure());
-    verifyLoggedAtom(
-        DownloadStatus.FAILED_AND_ABORT,
-        MAX_RUN_ATTEMPT_COUNT,
-        FailureReason.UNKNOWN_FAILURE_REASON);
   }
 
   @Test
-  public void workerStopped() throws Exception {
-    NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
-    worker.onStopped();
-    verifyLoggedAtom(
-        DownloadStatus.FAILED_AND_RETRY, RUN_ATTEMPT_COUNT, FailureReason.WORKER_STOPPED);
+  public void downloadSkipped_reachManifestMaxAttempts() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    // The current default max attempt count per manifest is 2, hence two recorded failures.
+    downloadedModelManager.registerManifestDownloadFailure(MANIFEST_URL);
+    downloadedModelManager.registerManifestDownloadFailure(MANIFEST_URL);
+
+    NewModelDownloadWorker worker = createWorker(MANIFEST_MAX_ATTEMPT_COUNT);
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(loggerTestRule.getLoggedAtoms()).isEmpty();
+  }
+
+  @Test
+  public void downloadSkipped_manifestAlreadyProcessed() throws Exception {
+    setUpManifestUrl(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    modelFile.createNewFile();
+    downloadedModelManager.registerModel(MODEL_URL, modelFile.getAbsolutePath());
+    downloadedModelManager.registerManifest(MANIFEST_URL, MODEL_URL);
+    downloadedModelManager.registerManifestEnrollment(MODEL_TYPE, LOCALE_TAG, MANIFEST_URL);
+    // Everything is pre-registered, so an ordinary first-attempt run should log nothing.
+    NewModelDownloadWorker worker = createWorker(RUN_ATTEMPT_COUNT);
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(loggerTestRule.getLoggedAtoms()).isEmpty();
   }
 
   private NewModelDownloadWorker createWorker(int runAttemptCount) {
     return TestListenableWorkerBuilder.from(
             ApplicationProvider.getApplicationContext(), NewModelDownloadWorker.class)
-        .setInputData(
-            NewModelDownloadWorker.createInputData(
-                MODEL_TYPE, LOCALE_TAG, MANIFEST_URL, MAX_RUN_ATTEMPT_COUNT))
         .setRunAttemptCount(runAttemptCount)
         .setWorkerFactory(
             new WorkerFactory() {
@@ -224,8 +338,8 @@
                     workerParameters,
                     MoreExecutors.newDirectExecutorService(),
                     modelDownloader,
-                    downloadedModelDir,
-                    postDownloadCleanUpRunnable);
+                    downloadedModelManager,
+                    settings);
               }
             })
         .build();
@@ -243,4 +357,11 @@
       assertThat(atom.getFailureReason()).isEqualTo(failureReason);
     }
   }
+
+  private void setUpManifestUrl(
+      @ModelType.ModelTypeDef String modelType, String localeTag, String url) {
+    // Sets the flag named by MANIFEST_URL_TEMPLATE for this model type and locale to the URL.
+    deviceConfig.setConfig(
+        String.format(TextClassifierSettings.MANIFEST_URL_TEMPLATE, modelType, localeTag), url);
+  }
 }
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/downloader/TestModelDownloaderService.java b/java/tests/instrumentation/src/com/android/textclassifier/downloader/TestModelDownloaderService.java
index 172fa68..97a55fb 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/downloader/TestModelDownloaderService.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/downloader/TestModelDownloaderService.java
@@ -25,6 +25,7 @@
 import java.util.concurrent.CountDownLatch;
 import javax.annotation.Nullable;
 
+// TODO(licha): Find another way to test the service. This static state is fragile.
 /** Test Service of IModelDownloaderService. */
 public final class TestModelDownloaderService extends Service {
   private static final String TAG = "TestModelDownloaderService";
@@ -38,13 +39,13 @@
   public enum DownloadResult {
     SUCCEEDED,
     FAILED,
-    BLOCKING,
     DO_NOTHING
   }
 
   // Obviously this does not work when considering concurrency, but probably fine for test purpose
   private static boolean boundBefore = false;
   private static boolean boundNow = false;
+  private static CountDownLatch onBindInvokedLatch = new CountDownLatch(1);
   private static CountDownLatch onUnbindInvokedLatch = new CountDownLatch(1);
 
   private static boolean bindSucceed = false;
@@ -60,6 +61,10 @@
     return boundNow;
   }
 
+  public static CountDownLatch getOnBindInvokedLatch() {
+    return onBindInvokedLatch;
+  }
+
   public static CountDownLatch getOnUnbindInvokedLatch() {
     return onUnbindInvokedLatch;
   }
@@ -78,26 +83,34 @@
   public static void reset() {
     boundBefore = false;
     boundNow = false;
+    onBindInvokedLatch = new CountDownLatch(1);
     onUnbindInvokedLatch = new CountDownLatch(1);
     bindSucceed = false;
   }
 
   @Override
   public IBinder onBind(Intent intent) {
-    if (bindSucceed) {
-      boundBefore = true;
-      boundNow = true;
-      return new TestModelDownloaderServiceImpl();
-    } else {
-      return null;
+    try {
+      if (bindSucceed) {
+        boundBefore = true;
+        boundNow = true;
+        return new TestModelDownloaderServiceImpl();
+      } else {
+        return null;
+      }
+    } finally {
+      onBindInvokedLatch.countDown();
     }
   }
 
   @Override
   public boolean onUnbind(Intent intent) {
-    boundNow = false;
-    onUnbindInvokedLatch.countDown();
-    return false;
+    try {
+      boundNow = false;
+      return false;
+    } finally {
+      onUnbindInvokedLatch.countDown();
+    }
   }
 
   private static final class TestModelDownloaderServiceImpl extends IModelDownloaderService.Stub {
@@ -118,10 +131,6 @@
           case FAILED:
             callback.onFailure(ERROR_CODE, ERROR_MSG);
             break;
-          case BLOCKING:
-            Thread.sleep(1000 * 60L);
-            TcLog.w(TAG, "Blocking request returns.");
-            break;
           case DO_NOTHING:
             // Do nothing
         }