Import TextClassifier to AOSP.
This is the first time downloader code is introduced to AOSP.
Test: m TextClassifierService & atest -p external/libtextclassifier
Bug: 180441353
Change-Id: I8b69daced97760fb0f3b0cc882107c39e97eb0a0
diff --git a/java/Android.bp b/java/Android.bp
index ca34a66..30fd2bc 100644
--- a/java/Android.bp
+++ b/java/Android.bp
@@ -52,7 +52,10 @@
// Similar to TextClassifierServiceLib, but without the AndroidManifest.
android_library {
name: "TextClassifierServiceLibNoManifest",
- srcs: ["src/**/*.java"],
+ srcs: [
+ "src/**/*.java",
+ "src/**/*.aidl",
+ ],
manifest: "LibNoManifest_AndroidManifest.xml",
static_libs: [
"androidx.core_core",
@@ -61,6 +64,10 @@
"guava",
"textclassifier-statsd",
"error_prone_annotations",
+ "androidx.work_work-runtime",
+ "android_downloader_lib",
+ "textclassifier-java-proto-lite",
+ "androidx.concurrent_concurrent-futures",
],
sdk_version: "system_current",
min_sdk_version: "30",
diff --git a/java/AndroidManifest.xml b/java/AndroidManifest.xml
index f2dfcb7..083991c 100644
--- a/java/AndroidManifest.xml
+++ b/java/AndroidManifest.xml
@@ -33,9 +33,20 @@
<uses-permission android:name="android.permission.ACCESS_COARSE_LOCATION" />
<uses-permission android:name="android.permission.RECEIVE_BOOT_COMPLETED" />
<uses-permission android:name="android.permission.ACCESS_NETWORK_STATE"/>
- <uses-permission android:name="android.permission.INTERNET" />
+
+ <!-- The INTERNET permission is restricted to the modelDownloaderServiceProcess -->
+ <uses-permission android:name="android.permission.INTERNET"/>
<application>
+ <processes>
+ <process>
+ <deny-permission android:name="android.permission.INTERNET" />
+ </process>
+ <process android:process=":modelDownloaderServiceProcess">
+ <allow-permission android:name="android.permission.INTERNET" />
+ </process>
+ </processes>
+
<service
android:exported="true"
android:name=".DefaultTextClassifierService"
@@ -44,6 +55,11 @@
<action android:name="android.service.textclassifier.TextClassifierService"/>
</intent-filter>
</service>
-
+ <service
+ android:exported="false"
+ android:name=".ModelDownloaderService"
+ android:process=":modelDownloaderServiceProcess">
+ </service>
</application>
+
</manifest>
diff --git a/java/src/com/android/textclassifier/AbstractDownloadWorker.java b/java/src/com/android/textclassifier/AbstractDownloadWorker.java
new file mode 100644
index 0000000..43150fc
--- /dev/null
+++ b/java/src/com/android/textclassifier/AbstractDownloadWorker.java
@@ -0,0 +1,194 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.content.Context;
+import androidx.work.Data;
+import androidx.work.ListenableWorker;
+import androidx.work.WorkerParameters;
+import com.android.textclassifier.common.base.TcLog;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.util.concurrent.FluentFuture;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.net.URI;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+/**
+ * Abstract worker to download specified manifest/model. Subclasses only need to implement the logic
+ * to handle the downloaded file. Scheduled/executed by WorkManager.
+ *
+ * <p>The background executor is shut down once the work finishes or is stopped, so it must not be
+ * shared with other workers.
+ */
+abstract class AbstractDownloadWorker extends ListenableWorker {
+  private static final String TAG = "DownloadWorker";
+
+  @VisibleForTesting static final String DATA_URL_KEY = "DownloadWorker_url";
+
+  @VisibleForTesting
+  static final String DATA_DESTINATION_PATH_KEY = "DownloadWorker_destinationPath";
+
+  @VisibleForTesting
+  static final String DATA_REUSE_EXISTING_FILE_KEY = "DownloadWorker_reuseExistingFile";
+
+  @VisibleForTesting
+  static final String DATA_MAX_DOWNLOAD_ATTEMPTS_KEY = "DownloadWorker_maxAttempts";
+
+  private static final boolean DATA_REUSE_EXISTING_FILE_DEFAULT = false;
+  private static final int DATA_MAX_DOWNLOAD_ATTEMPTS_DEFAULT = 5;
+
+  private final String url;
+  private final String destinationPath;
+  private final boolean reuseExistingFile;
+  private final int maxDownloadAttempts;
+
+  // TODO(licha): Maybe create some static executors and share them across tcs
+  private final ExecutorService bgExecutorService;
+  private final ModelDownloader downloader;
+
+  AbstractDownloadWorker(Context context, WorkerParameters workerParams) {
+    this(context, workerParams, Executors.newSingleThreadExecutor());
+  }
+
+  private AbstractDownloadWorker(
+      Context context, WorkerParameters workerParams, ExecutorService bgExecutorService) {
+    this(
+        context,
+        workerParams,
+        bgExecutorService,
+        new ModelDownloaderImpl(context, bgExecutorService));
+  }
+
+  @VisibleForTesting
+  AbstractDownloadWorker(
+      Context context,
+      WorkerParameters workerParams,
+      ExecutorService bgExecutorService,
+      ModelDownloader downloader) {
+    super(context, workerParams);
+
+    this.url = Preconditions.checkNotNull(getInputData().getString(DATA_URL_KEY));
+    this.destinationPath =
+        Preconditions.checkNotNull(getInputData().getString(DATA_DESTINATION_PATH_KEY));
+    this.reuseExistingFile =
+        getInputData().getBoolean(DATA_REUSE_EXISTING_FILE_KEY, DATA_REUSE_EXISTING_FILE_DEFAULT);
+    this.maxDownloadAttempts =
+        getInputData().getInt(DATA_MAX_DOWNLOAD_ATTEMPTS_KEY, DATA_MAX_DOWNLOAD_ATTEMPTS_DEFAULT);
+
+    this.bgExecutorService = Preconditions.checkNotNull(bgExecutorService);
+    this.downloader = Preconditions.checkNotNull(downloader);
+  }
+
+  /**
+   * Downloads the URL to the destination path (or reuses an existing file when allowed) and
+   * passes the file to {@link #handleDownloadedFile}. Any failure deletes the target file and
+   * requests a retry; after the maximum number of attempts the work fails instead.
+   */
+  @Override
+  public final ListenableFuture<ListenableWorker.Result> startWork() {
+    TcLog.d(
+        TAG,
+        String.format(
+            "Start download: from %s to %s, attempt:%d",
+            url, destinationPath, getRunAttemptCount()));
+    if (getRunAttemptCount() >= maxDownloadAttempts) {
+      TcLog.d(TAG, "Max attempt reached. Abort download task.");
+      // The executor is never used on this path, so release its thread right away.
+      bgExecutorService.shutdown();
+      return Futures.immediateFuture(ListenableWorker.Result.failure());
+    }
+
+    File targetFile = new File(destinationPath);
+    ListenableFuture<Long> downloadFuture =
+        (reuseExistingFile && targetFile.exists())
+            ? Futures.immediateFuture(targetFile.length())
+            : downloader.download(URI.create(url), targetFile);
+
+    ListenableFuture<ListenableWorker.Result> resultFuture =
+        FluentFuture.from(downloadFuture)
+            .transform(
+                unusedBytesWritten -> {
+                  if (!targetFile.exists()) {
+                    throw new IllegalStateException(
+                        "Download succeeded but target file not found.");
+                  }
+                  handleDownloadedFile(targetFile);
+                  return ListenableWorker.Result.success();
+                },
+                bgExecutorService)
+            .catching(
+                Throwable.class,
+                e -> {
+                  TcLog.e(TAG, "Download attempt failed.", e);
+                  // Always delete the downloaded file if the work fails.
+                  targetFile.delete();
+                  // Retry until the max allowed attempts is reached (attempt count starts from 0).
+                  // The backoff time between two tries will grow exponentially (i.e. 30s, 1min,
+                  // 2min, 4min). This is configurable when building the request.
+                  return ListenableWorker.Result.retry();
+                },
+                bgExecutorService);
+    // onStopped() is only invoked when the work is interrupted, not when it completes normally,
+    // so release the executor's thread here as well; otherwise every worker leaks one thread.
+    resultFuture.addListener(bgExecutorService::shutdown, MoreExecutors.directExecutor());
+    return resultFuture;
+  }
+
+  /**
+   * Subclass Workers should override (and only override) this method to handle the downloaded
+   * file (e.g. validation, rename). They should throw unchecked Exceptions if failure occurs.
+   */
+  abstract Void handleDownloadedFile(File downloadedFile);
+
+  /**
+   * This method will be called when our work gets interrupted by the system. The result future
+   * should have already been cancelled in that case. Unless it's because of the REPLACE policy of
+   * the WorkManager unique queue, the interrupted work will be rescheduled later.
+   */
+  @Override
+  public final void onStopped() {
+    TcLog.d(
+        TAG,
+        String.format(
+            "Stop download: from %s to %s, attempt:%d",
+            url, destinationPath, getRunAttemptCount()));
+    bgExecutorService.shutdown();
+  }
+
+  /**
+   * Helper to create a base input Data builder.
+   *
+   * @param url the URL from where to download content
+   * @param destinationPath the path on the device to store the downloaded file
+   * @param reuseExistingFile if true, we will skip the download if a file exists in destinationPath
+   * @param maxDownloadAttempts max times to try before we abort this download task
+   * @return a builder pre-populated with the keys read by this worker; subclasses may add more
+   */
+  static final Data.Builder createInputDataBuilder(
+      String url, String destinationPath, boolean reuseExistingFile, int maxDownloadAttempts) {
+    return new Data.Builder()
+        .putString(DATA_URL_KEY, url)
+        .putString(DATA_DESTINATION_PATH_KEY, destinationPath)
+        .putBoolean(DATA_REUSE_EXISTING_FILE_KEY, reuseExistingFile)
+        .putInt(DATA_MAX_DOWNLOAD_ATTEMPTS_KEY, maxDownloadAttempts);
+  }
+}
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index 26f5a24..ca48a90 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -30,6 +30,7 @@
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
+import androidx.work.WorkManager;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
import com.android.textclassifier.utils.IndentingPrintWriter;
@@ -46,6 +47,7 @@
import java.util.concurrent.Callable;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
+import javax.annotation.Nullable;
/** An implementation of a TextClassifierService. */
public final class DefaultTextClassifierService extends TextClassifierService {
@@ -55,6 +57,9 @@
// TODO: Figure out do we need more concurrency.
private ListeningExecutorService normPriorityExecutor;
private ListeningExecutorService lowPriorityExecutor;
+
+ @Nullable private ModelDownloadManager modelDownloadManager;
+
private TextClassifierImpl textClassifier;
private TextClassifierSettings settings;
private ModelFileManager modelFileManager;
@@ -82,6 +87,17 @@
textClassifier = injector.createTextClassifierImpl(settings, modelFileManager);
localeChangedReceiver = new LocaleChangedReceiver(modelFileManager);
+ if (settings.isModelDownloadManagerEnabled()) {
+ modelDownloadManager =
+ new ModelDownloadManager(
+ WorkManager.getInstance(this),
+ ManifestDownloadWorker.class,
+ modelFileManager,
+ settings,
+ lowPriorityExecutor);
+ modelDownloadManager.init();
+ }
+
textClassifierApiUsageLogger =
injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor);
diff --git a/java/src/com/android/textclassifier/IModelDownloaderCallback.aidl b/java/src/com/android/textclassifier/IModelDownloaderCallback.aidl
new file mode 100644
index 0000000..7f9d7fb
--- /dev/null
+++ b/java/src/com/android/textclassifier/IModelDownloaderCallback.aidl
@@ -0,0 +1,28 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+/**
+ * Callback through which ModelDownloaderService reports the result of a download request
+ * issued by ModelDownloaderImpl.
+ */
+oneway interface IModelDownloaderCallback {
+
+ void onSuccess(long bytesWritten);
+
+ void onFailure(String error);
+}
\ No newline at end of file
diff --git a/java/src/com/android/textclassifier/IModelDownloaderService.aidl b/java/src/com/android/textclassifier/IModelDownloaderService.aidl
new file mode 100644
index 0000000..d69f5ca
--- /dev/null
+++ b/java/src/com/android/textclassifier/IModelDownloaderService.aidl
@@ -0,0 +1,33 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import com.android.textclassifier.IModelDownloaderCallback;
+
+/**
+ * ModelDownloaderService binder interface. Calls are one-way; results are delivered via callback.
+ */
+oneway interface IModelDownloaderService {
+
+ /**
+ * @param url the full URL to download the model from
+ * @param targetFilePath the absolute file path for the destination file
+ * @param callback callback used to notify the caller of the download result
+ */
+ void download(
+ String url, String targetFilePath, IModelDownloaderCallback callback);
+}
\ No newline at end of file
diff --git a/java/src/com/android/textclassifier/ManifestDownloadWorker.java b/java/src/com/android/textclassifier/ManifestDownloadWorker.java
new file mode 100644
index 0000000..f067ccf
--- /dev/null
+++ b/java/src/com/android/textclassifier/ManifestDownloadWorker.java
@@ -0,0 +1,197 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.content.Context;
+import androidx.work.Constraints;
+import androidx.work.Data;
+import androidx.work.ExistingWorkPolicy;
+import androidx.work.ListenableWorker;
+import androidx.work.NetworkType;
+import androidx.work.OneTimeWorkRequest;
+import androidx.work.WorkManager;
+import androidx.work.WorkerParameters;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.common.base.TcLog;
+import com.android.textclassifier.protobuf.ExtensionRegistryLite;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.EnumBiMap;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.util.concurrent.Futures;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.time.Duration;
+
+/** Worker to download/parse models' manifest file and schedule the actual model download task. */
+public final class ManifestDownloadWorker extends AbstractDownloadWorker {
+  private static final String TAG = "ManifestDownloadWorker";
+  private static final String DATA_MODEL_TYPE_KEY = "ManifestDownloadWorker_modelType";
+  private static final String DATA_MODEL_LANGUAGE_TAG_KEY =
+      "ManifestDownloadWorker_modelLanguageTag";
+  private static final String DATA_MANIFEST_URL_KEY = "ManifestDownloadWorker_manifestUrl";
+  private static final String DATA_TARGET_MODEL_PATH_KEY = "ManifestDownloadWorker_targetModelPath";
+
+  private static final EnumBiMap<ModelManifest.NetworkType, NetworkType> NETWORK_TYPE_MAP =
+      EnumBiMap.create(
+          ImmutableMap.of(
+              ModelManifest.NetworkType.UNMETERED, NetworkType.UNMETERED,
+              ModelManifest.NetworkType.METERED, NetworkType.METERED,
+              ModelManifest.NetworkType.NOT_REQUIRED, NetworkType.NOT_REQUIRED,
+              ModelManifest.NetworkType.NOT_ROAMING, NetworkType.NOT_ROAMING,
+              ModelManifest.NetworkType.CONNECTED, NetworkType.CONNECTED));
+
+  private final String modelType;
+  private final String modelLanguageTag;
+  private final String manifestUrl;
+  private final String targetModelPath;
+
+  private final Context context;
+  private final Class<? extends ListenableWorker> modelDownloadWorkerClass;
+  private final WorkManager workManager;
+
+  public ManifestDownloadWorker(Context context, WorkerParameters workerParams) {
+    this(context, workerParams, ModelDownloadWorker.class);
+  }
+
+  @VisibleForTesting
+  ManifestDownloadWorker(
+      Context context,
+      WorkerParameters workerParams,
+      Class<? extends ListenableWorker> modelDownloadWorkerClass) {
+    super(context, workerParams);
+
+    this.modelType = Preconditions.checkNotNull(getInputData().getString(DATA_MODEL_TYPE_KEY));
+    this.modelLanguageTag =
+        Preconditions.checkNotNull(getInputData().getString(DATA_MODEL_LANGUAGE_TAG_KEY));
+    this.manifestUrl = Preconditions.checkNotNull(getInputData().getString(DATA_MANIFEST_URL_KEY));
+    this.targetModelPath =
+        Preconditions.checkNotNull(getInputData().getString(DATA_TARGET_MODEL_PATH_KEY));
+
+    this.context = Preconditions.checkNotNull(context);
+    this.modelDownloadWorkerClass = Preconditions.checkNotNull(modelDownloadWorkerClass);
+    this.workManager = Preconditions.checkNotNull(WorkManager.getInstance(context));
+  }
+
+  /**
+   * Parses the downloaded manifest, validates its single model entry and enqueues a
+   * {@code modelDownloadWorkerClass} request (tagged with the manifest URL) into the model's
+   * unique work queue, replacing any pending request there.
+   *
+   * @throws IllegalStateException if the manifest cannot be read or fails validation
+   */
+  @Override
+  public Void handleDownloadedFile(File manifestFile) {
+    TcLog.d(TAG, "Start to parse model manifest: " + manifestFile.getAbsolutePath());
+    ModelManifest modelManifest;
+    // try-with-resources so the file descriptor is released even if parsing fails.
+    try (FileInputStream manifestStream = new FileInputStream(manifestFile)) {
+      modelManifest =
+          ModelManifest.parseFrom(manifestStream, ExtensionRegistryLite.getEmptyRegistry());
+    } catch (IOException e) {
+      throw new IllegalStateException("Failed to parse the manifest file.", e);
+    }
+
+    Preconditions.checkState(
+        modelManifest.getModelsCount() == 1,
+        "Expected exactly one model in the manifest, found %s",
+        modelManifest.getModelsCount());
+    ModelManifest.Model model = modelManifest.getModels(0);
+    Preconditions.checkState(
+        model.getUrl().startsWith(ModelDownloadManager.TEXT_CLASSIFIER_URL_PREFIX),
+        "Unexpected model URL: %s",
+        model.getUrl());
+    Preconditions.checkState(
+        model.getSizeInBytes() > 0 && !model.getFingerprint().isEmpty(),
+        "Model size or fingerprint is missing in the manifest.");
+    // NETWORK_TYPE_MAP only covers the listed values; fail here instead of building a request
+    // whose network constraint is null.
+    NetworkType requiredNetworkType =
+        NETWORK_TYPE_MAP.get(modelManifest.getRequiredNetworkType());
+    Preconditions.checkState(
+        requiredNetworkType != null,
+        "Unsupported network type in the manifest: %s",
+        modelManifest.getRequiredNetworkType());
+
+    File targetModelFile = new File(targetModelPath);
+    File pendingModelFile =
+        new File(context.getCacheDir(), targetModelFile.getName() + ".pending");
+    OneTimeWorkRequest modelDownloadRequest =
+        new OneTimeWorkRequest.Builder(modelDownloadWorkerClass)
+            .setInputData(
+                ModelDownloadWorker.createInputData(
+                    model.getUrl(),
+                    model.getSizeInBytes(),
+                    model.getFingerprint(),
+                    manifestFile.getAbsolutePath(),
+                    pendingModelFile.getAbsolutePath(),
+                    targetModelPath,
+                    /* maxDownloadAttempts= */ 5,
+                    /* reuseExistingModelFile= */ false))
+            .addTag(manifestUrl)
+            .setConstraints(
+                new Constraints.Builder()
+                    .setRequiredNetworkType(requiredNetworkType)
+                    .setRequiresBatteryNotLow(modelManifest.getRequiresBatteryNotLow())
+                    .setRequiresCharging(modelManifest.getRequiresCharging())
+                    .setRequiresDeviceIdle(modelManifest.getRequiresDeviceIdle())
+                    .setRequiresStorageNotLow(modelManifest.getRequiresStorageNotLow())
+                    .build())
+            .keepResultsForAtLeast(
+                Duration.ofDays(ModelDownloadManager.DAYS_TO_KEEP_THE_DOWNLOAD_RESULT))
+            .build();
+
+    // Enqueue chained requests to a unique queue (different from the manifest queue)
+    Futures.getUnchecked(
+        workManager
+            .enqueueUniqueWork(
+                ModelDownloadManager.getModelUniqueWorkName(modelType, modelLanguageTag),
+                ExistingWorkPolicy.REPLACE,
+                modelDownloadRequest)
+            .getResult());
+    return null;
+  }
+
+  /**
+   * Creates input Data for a ManifestDownloadWorker.
+   *
+   * @param modelType the model type; combined with modelLanguageTag to name the model work queue
+   * @param modelLanguageTag the language tag; combined with modelType to name the model work queue
+   * @param manifestUrl the URL to download the manifest from; also tags the model download work
+   * @param targetManifestPath where to store the downloaded manifest
+   * @param targetModelPath the final destination of the model file
+   * @param maxDownloadAttempts max attempts for the manifest download
+   * @param reuseExistingManifestFile whether an existing file at targetManifestPath is reused
+   */
+  public static Data createInputData(
+      @ModelType.ModelTypeDef String modelType,
+      String modelLanguageTag,
+      String manifestUrl,
+      String targetManifestPath,
+      String targetModelPath,
+      int maxDownloadAttempts,
+      boolean reuseExistingManifestFile) {
+    return AbstractDownloadWorker.createInputDataBuilder(
+            manifestUrl, targetManifestPath, reuseExistingManifestFile, maxDownloadAttempts)
+        .putString(DATA_MODEL_TYPE_KEY, modelType)
+        .putString(DATA_MODEL_LANGUAGE_TAG_KEY, modelLanguageTag)
+        .putString(DATA_MANIFEST_URL_KEY, manifestUrl)
+        .putString(DATA_TARGET_MODEL_PATH_KEY, targetModelPath)
+        .build();
+  }
+}
diff --git a/java/src/com/android/textclassifier/ModelDownloadManager.java b/java/src/com/android/textclassifier/ModelDownloadManager.java
new file mode 100644
index 0000000..1e7879a
--- /dev/null
+++ b/java/src/com/android/textclassifier/ModelDownloadManager.java
@@ -0,0 +1,299 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.os.LocaleList;
+import android.provider.DeviceConfig;
+import android.text.TextUtils;
+import androidx.work.Constraints;
+import androidx.work.ExistingWorkPolicy;
+import androidx.work.ListenableWorker;
+import androidx.work.OneTimeWorkRequest;
+import androidx.work.WorkInfo;
+import androidx.work.WorkManager;
+import androidx.work.WorkQuery;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.common.base.TcLog;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import java.io.File;
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.ExecutionException;
+
+/** Manager to listen to config update and download latest models. */
+final class ModelDownloadManager {
+  private static final String TAG = "ModelDownloadManager";
+
+  static final String UNIVERSAL_MODEL_LANGUAGE_TAG = "universal";
+  static final String TEXT_CLASSIFIER_URL_PREFIX =
+      "https://www.gstatic.com/android/text_classifier/";
+  static final long DAYS_TO_KEEP_THE_DOWNLOAD_RESULT = 28L;
+
+  private final Object lock = new Object();
+
+  private final WorkManager workManager;
+  private final Class<? extends ListenableWorker> manifestDownloadWorkerClass;
+  private final ModelFileManager modelFileManager;
+  private final TextClassifierSettings settings;
+  private final ListeningExecutorService executorService;
+  private final DeviceConfig.OnPropertiesChangedListener deviceConfigListener;
+
+  /**
+   * Constructor for ModelDownloadManager.
+   *
+   * @param workManager singleton WorkManager instance
+   * @param manifestDownloadWorkerClass WorkManager's Worker class to download model manifest and
+   *     schedule the actual model download work
+   * @param modelFileManager ModelFileManager to interact with storage layer
+   * @param settings TextClassifierSettings to access DeviceConfig and other settings
+   * @param executorService background executor service
+   */
+  public ModelDownloadManager(
+      WorkManager workManager,
+      Class<? extends ListenableWorker> manifestDownloadWorkerClass,
+      ModelFileManager modelFileManager,
+      TextClassifierSettings settings,
+      ListeningExecutorService executorService) {
+    this.workManager = Preconditions.checkNotNull(workManager);
+    this.manifestDownloadWorkerClass = Preconditions.checkNotNull(manifestDownloadWorkerClass);
+    this.modelFileManager = Preconditions.checkNotNull(modelFileManager);
+    this.settings = Preconditions.checkNotNull(settings);
+    this.executorService = Preconditions.checkNotNull(executorService);
+
+    this.deviceConfigListener =
+        new DeviceConfig.OnPropertiesChangedListener() {
+          @Override
+          public void onPropertiesChanged(DeviceConfig.Properties unused) {
+            // Trigger the check even when the change is unrelated just in case we missed a
+            // previous update
+            checkConfigAndScheduleDownloads();
+          }
+        };
+  }
+
+  /**
+   * Registers a listener to related DeviceConfig flag changes. Will also download models with
+   * {@code executorService} if necessary.
+   */
+  public void init() {
+    DeviceConfig.addOnPropertiesChangedListener(
+        DeviceConfig.NAMESPACE_TEXTCLASSIFIER, executorService, deviceConfigListener);
+    TcLog.d(TAG, "DeviceConfig listener registered by ModelDownloadManager");
+    // Check flags in background, in case any updates happened before the TCS starts
+    executorService.execute(this::checkConfigAndScheduleDownloads);
+  }
+
+  /** Un-register the listener to DeviceConfig. */
+  public void destroy() {
+    DeviceConfig.removeOnPropertiesChangedListener(deviceConfigListener);
+    TcLog.d(TAG, "DeviceConfig listener unregistered by ModelDownloadeManager");
+  }
+
+  /**
+   * Checks DeviceConfig and schedules new model download requests synchronously. This method
+   * holds {@code lock} and contains blocking operations, only call it in a background executor.
+   */
+  private void checkConfigAndScheduleDownloads() {
+    TcLog.v(TAG, "Checking DeviceConfig to see whether there are new models to download");
+    synchronized (lock) {
+      List<Locale.LanguageRange> languageRanges =
+          Locale.LanguageRange.parse(LocaleList.getAdjustedDefault().toLanguageTags());
+      for (String modelType : ModelType.values()) {
+        try {
+          maybeScheduleDownload(modelType, languageRanges);
+        } catch (InterruptedException e) {
+          // Restore the interrupt status, which was cleared when the exception was thrown, and
+          // stop scheduling further downloads.
+          Thread.currentThread().interrupt();
+          TcLog.e(TAG, "Interrupted while scheduling model downloads. Abort.", e);
+          return;
+        }
+      }
+    }
+  }
+
+  /**
+   * Schedules a manifest download for {@code modelType} unless DeviceConfig has no model for it,
+   * or the model is already queued, already on disk, or was successfully downloaded recently.
+   *
+   * @throws InterruptedException if interrupted while blocking on WorkManager
+   */
+  private void maybeScheduleDownload(
+      @ModelType.ModelTypeDef String modelType, List<Locale.LanguageRange> languageRanges)
+      throws InterruptedException {
+    // Notice: Be careful of the Locale.lookupTag() matching logic: 1) it will convert the tag
+    // to lower-case only; 2) it matches tags from tail to head and does not allow missing
+    // pieces. E.g. if your system locale is zh-hans-cn, it won't match zh-cn.
+    String bestTag =
+        Locale.lookupTag(languageRanges, settings.getLanguageTagsForManifestURLSuffix(modelType));
+    String modelLanguageTag = bestTag != null ? bestTag : UNIVERSAL_MODEL_LANGUAGE_TAG;
+
+    // One manifest url suffix can uniquely identify a model in the world
+    String manifestUrlSuffix = settings.getManifestURLSuffix(modelType, modelLanguageTag);
+    if (TextUtils.isEmpty(manifestUrlSuffix)) {
+      return;
+    }
+    String manifestUrl = TEXT_CLASSIFIER_URL_PREFIX + manifestUrlSuffix;
+
+    // Check whether a manifest or a model is in the queue/in the middle of downloading. Both
+    // manifest/model works are tagged with the manifest URL.
+    WorkQuery pendingWorkQuery =
+        WorkQuery.Builder.fromTags(ImmutableList.of(manifestUrl))
+            .addStates(
+                Arrays.asList(
+                    WorkInfo.State.BLOCKED, WorkInfo.State.ENQUEUED, WorkInfo.State.RUNNING))
+            .build();
+    if (hasMatchingWork(
+        pendingWorkQuery, "Failed to query queued requests. Ignore and continue.")) {
+      TcLog.v(TAG, "Target model is already in the download queue.");
+      return;
+    }
+
+    // Target file's name has the url suffix encoded in it
+    File targetModelFile = modelFileManager.getDownloadTargetFile(modelType, manifestUrlSuffix);
+    File targetModelDir = targetModelFile.getParentFile();
+    if (!targetModelDir.exists() && !targetModelDir.mkdirs()) {
+      TcLog.e(TAG, "Failed to create " + targetModelDir.getAbsolutePath());
+      return;
+    }
+    // TODO(licha): Ideally, we should verify whether the existing file can be loaded
+    // successfully
+    // Notes: We also don't check factory models and models downloaded by ConfigUpdater. But
+    // this is probably fine because it's unlikely to have an overlap.
+    if (targetModelFile.exists()) {
+      TcLog.v(TAG, "Target model is already in the storage.");
+      return;
+    }
+
+    // Skip models downloaded successfully in (at least) past DAYS_TO_KEEP_THE_DOWNLOAD_RESULT
+    // Because we delete less-preferred models after one model downloaded, it's possible that
+    // we fall in a loop (download - delete - download again) if P/H flag is in a bad state.
+    // NOTICE: Because we use an unique work name here, if we download model-1 first and then
+    // model-2, then model-1's WorkInfo will be lost. In that case, if the flag goes back to
+    // model-1, we will download it again even if it's within DAYS_TO_KEEP_THE_DOWNLOAD_RESULT
+    WorkQuery downloadedBeforeWorkQuery =
+        WorkQuery.Builder.fromTags(ImmutableList.of(manifestUrl))
+            .addStates(ImmutableList.of(WorkInfo.State.SUCCEEDED))
+            .addUniqueWorkNames(
+                ImmutableList.of(getModelUniqueWorkName(modelType, modelLanguageTag)))
+            .build();
+    if (hasMatchingWork(
+        downloadedBeforeWorkQuery,
+        "Failed to query previously succeeded requests. Ignore and continue.")) {
+      TcLog.v(TAG, "The model was downloaded successfully before and got cleaned-up later");
+      return;
+    }
+
+    scheduleManifestDownload(modelType, modelLanguageTag, manifestUrl, targetModelFile);
+  }
+
+  /**
+   * Enqueues a manifest download request into the model's unique manifest queue (replacing any
+   * pending request there) and blocks until WorkManager accepts it.
+   *
+   * @throws InterruptedException if interrupted while waiting for the enqueue to complete
+   */
+  private void scheduleManifestDownload(
+      @ModelType.ModelTypeDef String modelType,
+      String modelLanguageTag,
+      String manifestUrl,
+      File targetModelFile)
+      throws InterruptedException {
+    String targetModelPath = targetModelFile.getAbsolutePath();
+    String targetManifestPath = getTargetManifestPath(targetModelPath);
+    OneTimeWorkRequest manifestDownloadRequest =
+        new OneTimeWorkRequest.Builder(manifestDownloadWorkerClass)
+            .setInputData(
+                ManifestDownloadWorker.createInputData(
+                    modelType,
+                    modelLanguageTag,
+                    manifestUrl,
+                    targetManifestPath,
+                    targetModelPath,
+                    settings.getModelDownloadMaxAttempts(),
+                    /* reuseExistingManifestFile= */ true))
+            .addTag(manifestUrl)
+            .setConstraints(
+                new Constraints.Builder()
+                    .setRequiredNetworkType(settings.getManifestDownloadRequiredNetworkType())
+                    .setRequiresBatteryNotLow(true)
+                    .setRequiresStorageNotLow(true)
+                    .build())
+            .keepResultsForAtLeast(Duration.ofDays(DAYS_TO_KEEP_THE_DOWNLOAD_RESULT))
+            .build();
+
+    // When we enqueue a new request, existing pending request in the same queue will be
+    // cancelled. With this, device will be able to abort previous unfinished downloads
+    // (e.g. v711) when a fresher model (e.g. v712) is available.
+    try {
+      // Block until we enqueue the request successfully
+      workManager
+          .enqueueUniqueWork(
+              getManifestUniqueWorkName(modelType, modelLanguageTag),
+              ExistingWorkPolicy.REPLACE,
+              manifestDownloadRequest)
+          .getResult()
+          .get();
+      TcLog.d(TAG, "Download scheduled: " + manifestUrl);
+    } catch (ExecutionException e) {
+      TcLog.e(TAG, "Failed to enqueue a request", e);
+    }
+  }
+
+  /**
+   * Returns whether any work matches {@code workQuery}. A failed query is logged with
+   * {@code failureMessage} and treated as "no match" so that scheduling can continue.
+   *
+   * @throws InterruptedException if interrupted while waiting for the query result
+   */
+  private boolean hasMatchingWork(WorkQuery workQuery, String failureMessage)
+      throws InterruptedException {
+    try {
+      return !workManager.getWorkInfos(workQuery).get().isEmpty();
+    } catch (ExecutionException e) {
+      TcLog.e(TAG, failureMessage, e);
+      return false;
+    }
+  }
+
+  @VisibleForTesting
+  void checkConfigAndScheduleDownloadsForTesting() {
+    checkConfigAndScheduleDownloads();
+  }
+
+  @VisibleForTesting
+  static String getTargetManifestPath(String targetModelPath) {
+    return targetModelPath + ".manifest";
+  }
+
+  @VisibleForTesting
+  static String getManifestUniqueWorkName(
+      @ModelType.ModelTypeDef String modelType, String modelLanguageTag) {
+    return String.format("manifest-%s-%s", modelType, modelLanguageTag);
+  }
+
+  // ManifestDownloadWorker needs to access this
+  static String getModelUniqueWorkName(
+      @ModelType.ModelTypeDef String modelType, String modelLanguageTag) {
+    return "model-" + modelType + "-" + modelLanguageTag;
+  }
+}
diff --git a/java/src/com/android/textclassifier/ModelDownloadWorker.java b/java/src/com/android/textclassifier/ModelDownloadWorker.java
new file mode 100644
index 0000000..641af8a
--- /dev/null
+++ b/java/src/com/android/textclassifier/ModelDownloadWorker.java
@@ -0,0 +1,165 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.content.Context;
+import androidx.work.Data;
+import androidx.work.WorkerParameters;
+import com.android.textclassifier.common.base.TcLog;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.hash.HashCode;
+import com.google.common.hash.Hashing;
+import com.google.common.io.Files;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.StandardCopyOption;
+import java.util.Locale;
+
+/** Worker to download, validate and update model image files. */
+public final class ModelDownloadWorker extends AbstractDownloadWorker {
+  private static final String TAG = "ModelDownloadWorker";
+
+  @VisibleForTesting
+  static final String DATA_MANIFEST_PATH_KEY = "ModelDownloadWorker_manifestPath";
+
+  @VisibleForTesting
+  static final String DATA_TARGET_MODEL_PATH_KEY = "ModelDownloadWorker_targetModelPath";
+
+  @VisibleForTesting
+  static final String DATA_MODEL_SIZE_IN_BYTES_KEY = "ModelDownloadWorker_modelSizeInBytes";
+
+  @VisibleForTesting
+  static final String DATA_MODEL_FINGERPRINT_KEY = "ModelDownloadWorker_modelFingerprint";
+
+  private final String manifestPath;
+  private final String targetModelPath;
+  private final long modelSizeInBytes;
+  private final String modelFingerprint;
+  private final ModelFileManager modelFileManager;
+
+  /**
+   * Reads the download parameters from the worker's input {@link Data}.
+   *
+   * @throws NullPointerException if a required path or the fingerprint is missing
+   */
+  public ModelDownloadWorker(Context context, WorkerParameters workerParams) {
+    super(context, workerParams);
+    this.manifestPath =
+        Preconditions.checkNotNull(getInputData().getString(DATA_MANIFEST_PATH_KEY));
+    this.targetModelPath =
+        Preconditions.checkNotNull(getInputData().getString(DATA_TARGET_MODEL_PATH_KEY));
+    // A missing size defaults to 0, which only matches an empty download in validateModel.
+    this.modelSizeInBytes =
+        getInputData().getLong(DATA_MODEL_SIZE_IN_BYTES_KEY, /* defaultValue= */ 0L);
+    this.modelFingerprint =
+        Preconditions.checkNotNull(getInputData().getString(DATA_MODEL_FINGERPRINT_KEY));
+    this.modelFileManager = new ModelFileManager(context, new TextClassifierSettings());
+  }
+
+  /**
+   * Validates the downloaded model and, if it passes, moves it atomically to the target path.
+   *
+   * <p>After a successful move, the manifest file is deleted and unused model files are cleaned
+   * up. The pending file is deleted in every case, so a failed validation leaves no copy behind.
+   *
+   * @throws IllegalStateException if validation, the move or the cleanup fails
+   */
+  @Override
+  public Void handleDownloadedFile(File pendingModelFile) {
+    TcLog.d(TAG, "Start to check pending model file: " + pendingModelFile.getAbsolutePath());
+    try {
+      validateModel(pendingModelFile, modelSizeInBytes, modelFingerprint);
+
+      File targetModelFile = new File(targetModelPath);
+      // Fully qualified because Guava's Files is imported for hashing.
+      java.nio.file.Files.move(
+          pendingModelFile.toPath(),
+          targetModelFile.toPath(),
+          StandardCopyOption.ATOMIC_MOVE,
+          StandardCopyOption.REPLACE_EXISTING);
+      TcLog.d(TAG, "Model file downloaded successfully: " + targetModelFile.getAbsolutePath());
+
+      // Clean up manifest and older models
+      new File(manifestPath).delete();
+      modelFileManager.deleteUnusedModelFiles();
+      return null;
+    } catch (Exception e) {
+      throw new IllegalStateException("Failed to validate or move pending model file.", e);
+    } finally {
+      pendingModelFile.delete();
+    }
+  }
+
+  /**
+   * Checks that {@code pendingModelFile} exists and matches the expected size and SHA-384
+   * fingerprint. Throws unchecked Exceptions if validation fails.
+   *
+   * <p>The expected fingerprint is lower-cased before parsing because {@link
+   * HashCode#fromString} only accepts lower-case hex digits.
+   *
+   * @throws IllegalStateException if the file is missing or its size or fingerprint differs
+   * @throws IllegalArgumentException if {@code fingerprint} is not a valid hex string
+   * @throws IOException if the file cannot be read while hashing
+   */
+  private static void validateModel(File pendingModelFile, long sizeInBytes, String fingerprint)
+      throws IOException {
+    if (!pendingModelFile.exists()) {
+      throw new IllegalStateException("PendingModelFile does not exist.");
+    }
+    long actualSizeInBytes = pendingModelFile.length();
+    if (actualSizeInBytes != sizeInBytes) {
+      throw new IllegalStateException(
+          String.format(
+              "PendingModelFile size does not match: expected [%d] actual [%d]",
+              sizeInBytes, actualSizeInBytes));
+    }
+    HashCode expectedFingerprint = HashCode.fromString(fingerprint.toLowerCase(Locale.ROOT));
+    HashCode pendingModelFingerprint = Files.asByteSource(pendingModelFile).hash(Hashing.sha384());
+    if (!pendingModelFingerprint.equals(expectedFingerprint)) {
+      throw new IllegalStateException(
+          String.format(
+              "PendingModelFile fingerprint does not match: expected [%s] actual [%s]",
+              fingerprint, pendingModelFingerprint));
+    }
+    TcLog.d(TAG, "Pending model file passed validation.");
+  }
+
+  /**
+   * Creates input Data for a ModelDownloadWorker.
+   *
+   * <p>The shared download keys come from {@link AbstractDownloadWorker#createInputDataBuilder};
+   * the keys added here are the ones read by this worker's constructor.
+   */
+  public static Data createInputData(
+      String modelUrl,
+      long modelSizeInBytes,
+      String modelFingerprint,
+      String manifestPath,
+      String pendingModelPath,
+      String targetModelPath,
+      int maxDownloadAttempts,
+      boolean reuseExistingModelFile) {
+    return AbstractDownloadWorker.createInputDataBuilder(
+            modelUrl, pendingModelPath, reuseExistingModelFile, maxDownloadAttempts)
+        .putString(DATA_MANIFEST_PATH_KEY, manifestPath)
+        .putString(DATA_TARGET_MODEL_PATH_KEY, targetModelPath)
+        .putLong(DATA_MODEL_SIZE_IN_BYTES_KEY, modelSizeInBytes)
+        .putString(DATA_MODEL_FINGERPRINT_KEY, modelFingerprint)
+        .build();
+  }
+}
diff --git a/java/src/com/android/textclassifier/ModelDownloader.java b/java/src/com/android/textclassifier/ModelDownloader.java
new file mode 100644
index 0000000..7839a9b
--- /dev/null
+++ b/java/src/com/android/textclassifier/ModelDownloader.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import com.google.common.util.concurrent.ListenableFuture;
+import java.io.File;
+import java.net.URI;
+
+/** Abstraction over a component that fetches a remote file identified by a {@link URI}. */
+public interface ModelDownloader {
+
+  /**
+   * Downloads the content at {@code uri} and writes it into {@code targetFile}.
+   *
+   * <p>On success the returned future holds the number of bytes written. On failure the future
+   * fails with an exception carrying more details. Implementations should delete any partially
+   * written model file when the download does not complete.
+   *
+   * @param uri the location to download the file from
+   * @param targetFile where the downloaded content is written; if the file already exists, its
+   *     previous content is replaced
+   * @return a future resolving to the number of bytes written to {@code targetFile}
+   */
+  ListenableFuture<Long> download(URI uri, File targetFile);
+}
diff --git a/java/src/com/android/textclassifier/ModelDownloaderImpl.java b/java/src/com/android/textclassifier/ModelDownloaderImpl.java
new file mode 100644
index 0000000..83eddde
--- /dev/null
+++ b/java/src/com/android/textclassifier/ModelDownloaderImpl.java
@@ -0,0 +1,191 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static android.content.Context.BIND_AUTO_CREATE;
+import static android.content.Context.BIND_NOT_FOREGROUND;
+
+import android.content.ComponentName;
+import android.content.Context;
+import android.content.Intent;
+import android.content.ServiceConnection;
+import android.os.IBinder;
+import androidx.concurrent.futures.CallbackToFutureAdapter;
+import com.android.textclassifier.common.base.TcLog;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import java.io.File;
+import java.net.URI;
+import java.util.concurrent.ExecutorService;
+
+/**
+ * ModelDownloader implementation that forwards requests to ModelDownloaderService. This is to
+ * restrict the INTERNET permission to the service process only (instead of the whole ExtServices).
+ */
+final class ModelDownloaderImpl implements ModelDownloader {
+  private static final String TAG = "ModelDownloaderImpl";
+
+  private final Context context;
+  private final ExecutorService bgExecutorService;
+  private final Class<?> downloaderServiceClass;
+
+  public ModelDownloaderImpl(Context context, ExecutorService bgExecutorService) {
+    this(context, bgExecutorService, ModelDownloaderService.class);
+  }
+
+  @VisibleForTesting
+  ModelDownloaderImpl(
+      Context context, ExecutorService bgExecutorService, Class<?> downloaderServiceClass) {
+    this.context = context.getApplicationContext();
+    this.bgExecutorService = bgExecutorService;
+    this.downloaderServiceClass = downloaderServiceClass;
+  }
+
+  /**
+   * Binds to the downloader service, forwards the request to it, and unbinds once the returned
+   * future completes, whether it succeeds, fails or is cancelled.
+   */
+  @Override
+  public ListenableFuture<Long> download(URI uri, File targetFile) {
+    DownloaderServiceConnection conn = new DownloaderServiceConnection();
+    ListenableFuture<IModelDownloaderService> downloaderServiceFuture = connect(conn);
+    ListenableFuture<Long> bytesWrittenFuture =
+        Futures.transformAsync(
+            downloaderServiceFuture,
+            service -> scheduleDownload(service, uri, targetFile),
+            bgExecutorService);
+    bytesWrittenFuture.addListener(
+        () -> {
+          try {
+            context.unbindService(conn);
+          } catch (IllegalArgumentException e) {
+            TcLog.e(TAG, "Error when unbind", e);
+          }
+        },
+        bgExecutorService);
+    return bytesWrittenFuture;
+  }
+
+  /**
+   * Starts binding to the downloader service. The returned future is completed by the callbacks
+   * of {@code conn}, or fails immediately if {@code bindService} returns false.
+   */
+  private ListenableFuture<IModelDownloaderService> connect(DownloaderServiceConnection conn) {
+    TcLog.d(TAG, "Starting a new connection to ModelDownloaderService");
+    return CallbackToFutureAdapter.getFuture(
+        completer -> {
+          conn.attachCompleter(completer);
+          Intent intent = new Intent(context, downloaderServiceClass);
+          if (context.bindService(intent, conn, BIND_AUTO_CREATE | BIND_NOT_FOREGROUND)) {
+            return "Binding to service";
+          } else {
+            completer.setException(new RuntimeException("Unable to bind to service"));
+            return "Binding failed";
+          }
+        });
+  }
+
+  // The returned download result future can be completed by: 1) the service invoking the
+  // callback with a result or an exception; 2) the death recipient below, if the service process
+  // dies before calling back; 3) CallbackToFutureAdapter failing the future when the callback is
+  // garbage collected. If none of them happens, the result future hangs until it times out
+  // (WorkManager forces a 10-min running time).
+  private static ListenableFuture<Long> scheduleDownload(
+      IModelDownloaderService service, URI uri, File targetFile) {
+    TcLog.d(TAG, "Scheduling a new download task with ModelDownloaderService");
+    return CallbackToFutureAdapter.getFuture(
+        completer -> {
+          IBinder binder = service.asBinder();
+          IBinder.DeathRecipient deathRecipient =
+              () -> completer.setException(new RuntimeException("DownloaderService died"));
+          // Throws RemoteException, failing the future, if the service is already dead.
+          binder.linkToDeath(deathRecipient, /* flags= */ 0);
+          try {
+            service.download(
+                uri.toString(),
+                targetFile.getAbsolutePath(),
+                new IModelDownloaderCallback.Stub() {
+                  @Override
+                  public void onSuccess(long bytesWritten) {
+                    if (completer.set(bytesWritten)) {
+                      unlinkQuietly(binder, deathRecipient);
+                    }
+                  }
+
+                  @Override
+                  public void onFailure(String errorMsg) {
+                    if (completer.setException(new RuntimeException(errorMsg))) {
+                      unlinkQuietly(binder, deathRecipient);
+                    }
+                  }
+                });
+          } catch (Exception e) {
+            // The request never started, so the death recipient is no longer needed.
+            unlinkQuietly(binder, deathRecipient);
+            throw e;
+          }
+          return "downloaderService.download";
+        });
+  }
+
+  /** Unregisters {@code recipient}, logging instead of throwing if the binder rejects it. */
+  private static void unlinkQuietly(IBinder binder, IBinder.DeathRecipient recipient) {
+    try {
+      binder.unlinkToDeath(recipient, /* flags= */ 0);
+    } catch (RuntimeException e) {
+      TcLog.e(TAG, "Error when unlinking DeathRecipient", e);
+    }
+  }
+
+  /** The implementation of {@link ServiceConnection} that handles changes in the connection. */
+  @VisibleForTesting
+  static class DownloaderServiceConnection implements ServiceConnection {
+    private static final String TAG = "ModelDownloaderImpl.DownloaderServiceConnection";
+
+    private CallbackToFutureAdapter.Completer<IModelDownloaderService> completer;
+
+    /** Called by connect() before bindService(), as every callback below uses the completer. */
+    public void attachCompleter(
+        CallbackToFutureAdapter.Completer<IModelDownloaderService> completer) {
+      this.completer = completer;
+    }
+
+    @Override
+    public void onServiceConnected(ComponentName componentName, IBinder iBinder) {
+      TcLog.d(TAG, "DownloaderService connected");
+      completer.set(Preconditions.checkNotNull(IModelDownloaderService.Stub.asInterface(iBinder)));
+    }
+
+    @Override
+    public void onServiceDisconnected(ComponentName componentName) {
+      // If this is invoked after onServiceConnected, it will be ignored by the completer.
+      completer.setException(new RuntimeException("Service disconnected"));
+    }
+
+    @Override
+    public void onBindingDied(ComponentName name) {
+      completer.setException(new RuntimeException("Binding died"));
+    }
+
+    @Override
+    public void onNullBinding(ComponentName name) {
+      completer.setException(new RuntimeException("Unable to bind to DownloaderService"));
+    }
+  }
+}
diff --git a/java/src/com/android/textclassifier/ModelDownloaderService.java b/java/src/com/android/textclassifier/ModelDownloaderService.java
new file mode 100644
index 0000000..6fe4ee9
--- /dev/null
+++ b/java/src/com/android/textclassifier/ModelDownloaderService.java
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.app.Service;
+import android.content.Intent;
+import android.os.IBinder;
+import com.android.textclassifier.common.base.TcLog;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+/** Service to expose IModelDownloaderService. */
+public final class ModelDownloaderService extends Service {
+  private static final String TAG = "ModelDownloaderService";
+
+  private ExecutorService executorService;
+  private IBinder iBinder;
+
+  /** Creates the executor and the single binder handed to every client of this service. */
+  @Override
+  public void onCreate() {
+    super.onCreate();
+    // TODO(licha): Use a shared thread pool for IO intensive tasks
+    this.executorService = Executors.newSingleThreadExecutor();
+    this.iBinder = new ModelDownloaderServiceImpl(executorService);
+  }
+
+  @Override
+  public IBinder onBind(Intent intent) {
+    TcLog.d(TAG, "Binding to ModelDownloaderService");
+    return iBinder;
+  }
+
+  /** Shuts down the executor (already-submitted tasks may still finish), then tears down. */
+  @Override
+  public void onDestroy() {
+    TcLog.d(TAG, "Destroying ModelDownloaderService");
+    executorService.shutdown();
+    super.onDestroy();
+  }
+}
diff --git a/java/src/com/android/textclassifier/ModelDownloaderServiceImpl.java b/java/src/com/android/textclassifier/ModelDownloaderServiceImpl.java
new file mode 100644
index 0000000..497beca
--- /dev/null
+++ b/java/src/com/android/textclassifier/ModelDownloaderServiceImpl.java
@@ -0,0 +1,168 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.os.RemoteException;
+import com.android.textclassifier.common.base.TcLog;
+import com.google.android.downloader.AndroidDownloaderLogger;
+import com.google.android.downloader.ConnectivityHandler;
+import com.google.android.downloader.DownloadConstraints;
+import com.google.android.downloader.DownloadRequest;
+import com.google.android.downloader.DownloadResult;
+import com.google.android.downloader.Downloader;
+import com.google.android.downloader.PlatformUrlEngine;
+import com.google.android.downloader.SimpleFileDownloadDestination;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
+import com.google.common.util.concurrent.FutureCallback;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.net.URI;
+import java.util.Collections;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import javax.annotation.concurrent.ThreadSafe;
+
+/** IModelDownloaderService implementation with Android Downloader library. */
+@ThreadSafe
+final class ModelDownloaderServiceImpl extends IModelDownloaderService.Stub {
+  private static final String TAG = "ModelDownloaderServiceImpl";
+
+  // Connectivity constraints will be checked by WorkManager instead.
+  private static class NoOpConnectivityHandler implements ConnectivityHandler {
+    @Override
+    public ListenableFuture<Void> checkConnectivity(DownloadConstraints constraints) {
+      return Futures.immediateVoidFuture();
+    }
+  }
+
+  private final ExecutorService bgExecutorService;
+  private final Downloader downloader;
+
+  /**
+   * Creates an instance that only registers an "https" URL engine and runs IO on the executor.
+   */
+  public ModelDownloaderServiceImpl(ExecutorService bgExecutorService) {
+    this.bgExecutorService = bgExecutorService;
+    this.downloader =
+        new Downloader.Builder()
+            .withIOExecutor(bgExecutorService)
+            .withConnectivityHandler(new NoOpConnectivityHandler())
+            .addUrlEngine(
+                Collections.singleton("https"),
+                new PlatformUrlEngine(
+                    // TODO(licha): use a shared thread pool
+                    // NOTE(review): this executor is never shut down, so every instance keeps
+                    // one extra thread alive.
+                    MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor()),
+                    /* connectTimeoutMs= */ 60 * 1000,
+                    /* readTimeoutMs= */ 60 * 1000))
+            .withLogger(new AndroidDownloaderLogger())
+            .build();
+  }
+
+  @VisibleForTesting
+  ModelDownloaderServiceImpl(ExecutorService bgExecutorService, Downloader downloader) {
+    this.bgExecutorService = Preconditions.checkNotNull(bgExecutorService);
+    this.downloader = Preconditions.checkNotNull(downloader);
+  }
+
+  /**
+   * Downloads {@code uri} to {@code targetFilePath}, reporting the result to {@code callback}.
+   *
+   * <p>Only URIs starting with {@code ModelDownloadManager.TEXT_CLASSIFIER_URL_PREFIX} are
+   * accepted. On failure the target file and its metadata file are deleted. Every error,
+   * including ones thrown synchronously here, is delivered through {@code callback.onFailure}.
+   */
+  @Override
+  public void download(String uri, String targetFilePath, IModelDownloaderCallback callback) {
+    TcLog.d(TAG, "Download request received: " + uri);
+    try {
+      Preconditions.checkArgument(
+          uri.startsWith(ModelDownloadManager.TEXT_CLASSIFIER_URL_PREFIX),
+          "Can only download TextClassifier resources, but uri is: %s",
+          uri);
+      File targetFile = new File(targetFilePath);
+      File tempMetadataFile = getMetadataFile(targetFile);
+      DownloadRequest request =
+          downloader
+              .newRequestBuilder(
+                  URI.create(uri), new SimpleFileDownloadDestination(targetFile, tempMetadataFile))
+              .build();
+      downloader
+          .execute(request)
+          .transform(DownloadResult::bytesWritten, MoreExecutors.directExecutor())
+          .addCallback(
+              new FutureCallback<Long>() {
+                @Override
+                public void onSuccess(Long bytesWritten) {
+                  tempMetadataFile.delete();
+                  dispatchOnSuccessSafely(callback, bytesWritten);
+                }
+
+                @Override
+                public void onFailure(Throwable t) {
+                  // TODO(licha): We may be able to resume the download if we keep those files
+                  targetFile.delete();
+                  tempMetadataFile.delete();
+                  dispatchOnFailureSafely(callback, t);
+                }
+              },
+              bgExecutorService);
+    } catch (Throwable t) {
+      dispatchOnFailureSafely(callback, t);
+    }
+  }
+
+  /** Returns the metadata file, next to {@code targetFile}, given to the download destination. */
+  @VisibleForTesting
+  static File getMetadataFile(File targetFile) {
+    return new File(targetFile.getParentFile(), targetFile.getName() + ".metadata");
+  }
+
+  /** Reports success to {@code callback}, logging if the client can no longer be reached. */
+  private static void dispatchOnSuccessSafely(
+      IModelDownloaderCallback callback, long bytesWritten) {
+    try {
+      callback.onSuccess(bytesWritten);
+    } catch (RemoteException e) {
+      TcLog.e(TAG, "Unable to notify successful download", e);
+    }
+  }
+
+  /**
+   * Logs {@code throwable} locally and forwards a description of it to {@code callback}.
+   *
+   * <p>Only a string reaches the client, so a throwable without a message is sent as its {@code
+   * toString()} (which names the exception class) rather than as {@code null}.
+   */
+  private static void dispatchOnFailureSafely(
+      IModelDownloaderCallback callback, Throwable throwable) {
+    TcLog.e(TAG, "Download failed", throwable);
+    String errorMsg = throwable.getMessage();
+    if (errorMsg == null) {
+      errorMsg = throwable.toString();
+    }
+    try {
+      callback.onFailure(errorMsg);
+    } catch (RemoteException e) {
+      TcLog.e(TAG, "Unable to notify failures in download", e);
+    }
+  }
+}
diff --git a/java/src/com/android/textclassifier/TextClassifierSettings.java b/java/src/com/android/textclassifier/TextClassifierSettings.java
index fa45147..e4ad140 100644
--- a/java/src/com/android/textclassifier/TextClassifierSettings.java
+++ b/java/src/com/android/textclassifier/TextClassifierSettings.java
@@ -21,7 +21,9 @@
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.TextClassifier;
import androidx.annotation.NonNull;
+import androidx.work.NetworkType;
import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Splitter;
@@ -348,6 +350,21 @@
NAMESPACE, MODEL_DOWNLOAD_MANAGER_ENABLED, MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT);
}
+  /** Returns the network type required for manifest downloads, as set by the device config. */
+  public NetworkType getManifestDownloadRequiredNetworkType() {
+    String defaultType = MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT;
+    String flagValue =
+        deviceConfig.getString(NAMESPACE, MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE, defaultType);
+    try {
+      return NetworkType.valueOf(flagValue);
+    } catch (IllegalArgumentException e) {
+      // NetworkType.valueOf rejects any string that is not an exact constant name, so a
+      // malformed flag is logged and the default is used instead.
+      TcLog.w(TAG, "Invalid manifest download required NetworkType: " + flagValue);
+      return NetworkType.valueOf(defaultType);
+    }
+  }
+
public int getModelDownloadMaxAttempts() {
return deviceConfig.getInt(
NAMESPACE, MODEL_DOWNLOAD_MAX_ATTEMPTS, MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT);
@@ -382,7 +399,8 @@
Properties properties = deviceConfig.getProperties(NAMESPACE);
ImmutableList.Builder<String> variantsBuilder = ImmutableList.builder();
for (String name : properties.getKeyset()) {
- if (name.startsWith(urlSuffixFlagBaseName)) {
+ if (name.startsWith(urlSuffixFlagBaseName)
+ && properties.getString(name, /* defaultValue= */ null) != null) {
variantsBuilder.add(name.substring(urlSuffixFlagBaseName.length()));
}
}
diff --git a/java/tests/instrumentation/Android.bp b/java/tests/instrumentation/Android.bp
index 871be1e..4ba5d07 100644
--- a/java/tests/instrumentation/Android.bp
+++ b/java/tests/instrumentation/Android.bp
@@ -45,7 +45,8 @@
"TextClassifierServiceLib",
"statsdprotolite",
"textclassifierprotoslite",
- "TextClassifierCoverageLib"
+ "TextClassifierCoverageLib",
+ "androidx.work_work-testing",
],
jni_libs: [
diff --git a/java/tests/instrumentation/AndroidManifest.xml b/java/tests/instrumentation/AndroidManifest.xml
index 3ee30da..e8cf968 100644
--- a/java/tests/instrumentation/AndroidManifest.xml
+++ b/java/tests/instrumentation/AndroidManifest.xml
@@ -8,6 +8,10 @@
<application>
<uses-library android:name="android.test.runner"/>
+ <service
+ android:exported="false"
+ android:name="com.android.textclassifier.TestModelDownloaderService">
+ </service>
</application>
<instrumentation
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/AbstractDownloadWorkerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/AbstractDownloadWorkerTest.java
new file mode 100644
index 0000000..fe3b853
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/AbstractDownloadWorkerTest.java
@@ -0,0 +1,249 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Context;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.work.Data;
+import androidx.work.ListenableWorker;
+import androidx.work.WorkerFactory;
+import androidx.work.WorkerParameters;
+import androidx.work.testing.TestListenableWorkerBuilder;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.io.FileWriter;
+import java.net.URI;
+import java.util.concurrent.ExecutorService;
+import java.util.function.Function;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link AbstractDownloadWorker}'s download, reuse, retry and failure handling. */
+@RunWith(JUnit4.class)
+public final class AbstractDownloadWorkerTest {
+  private static final String URL = "http://www.gstatic.com/android/text_classifier/q/v711/en.fb";
+  private static final String CONTENT_BYTES = "abc";
+  private static final int WORKER_MAX_DOWNLOAD_ATTEMPTS = 5;
+  private static final Function<File, Void> NO_OP_HANDLE_FUNC = f -> null;
+
+  private File targetModelFile;
+
+  @Before
+  public void setUp() {
+    this.targetModelFile =
+        new File(ApplicationProvider.getApplicationContext().getCacheDir(), "model.fb");
+    targetModelFile.deleteOnExit();
+  }
+
+  @Test
+  public void download_succeeded() throws Exception {
+    AbstractDownloadWorker worker =
+        createWorker(
+            createData(/* reuseExistingFile= */ false),
+            /* runAttemptCount= */ 0,
+            TestModelDownloader.withSuccess(CONTENT_BYTES),
+            NO_OP_HANDLE_FUNC);
+    targetModelFile.delete();
+
+    assertThat(targetModelFile.exists()).isFalse();
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(targetModelFile.exists()).isTrue();
+  }
+
+  @Test
+  public void download_reuseExistingFile() throws Exception {
+    AbstractDownloadWorker worker =
+        createWorker(
+            createData(/* reuseExistingFile= */ true),
+            /* runAttemptCount= */ 0,
+            // If we reuse existing file, downloader will not be invoked, thus won't fail
+            TestModelDownloader.withFailure(new Exception()),
+            NO_OP_HANDLE_FUNC);
+    targetModelFile.createNewFile();
+
+    assertThat(targetModelFile.exists()).isTrue();
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(targetModelFile.exists()).isTrue();
+  }
+
+  @Test
+  public void download_reuseExistingFileButNotExist() throws Exception {
+    AbstractDownloadWorker worker =
+        createWorker(
+            createData(/* reuseExistingFile= */ true),
+            /* runAttemptCount= */ 0,
+            TestModelDownloader.withSuccess(CONTENT_BYTES),
+            NO_OP_HANDLE_FUNC);
+    targetModelFile.delete();
+
+    assertThat(targetModelFile.exists()).isFalse();
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(targetModelFile.exists()).isTrue();
+  }
+
+  @Test
+  public void download_reuseExistingFileButNotExistAndFails() throws Exception {
+    AbstractDownloadWorker worker =
+        createWorker(
+            createData(/* reuseExistingFile= */ true),
+            /* runAttemptCount= */ 0,
+            TestModelDownloader.withFailure(new Exception()),
+            NO_OP_HANDLE_FUNC);
+    targetModelFile.delete();
+
+    assertThat(targetModelFile.exists()).isFalse();
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
+    assertThat(targetModelFile.exists()).isFalse();
+  }
+
+  @Test
+  public void download_failedAndRetry() throws Exception {
+    AbstractDownloadWorker worker =
+        createWorker(
+            createData(/* reuseExistingFile= */ false),
+            /* runAttemptCount= */ 0,
+            TestModelDownloader.withFailure(new Exception()),
+            NO_OP_HANDLE_FUNC);
+    targetModelFile.delete();
+
+    assertThat(targetModelFile.exists()).isFalse();
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
+    assertThat(targetModelFile.exists()).isFalse();
+  }
+
+  @Test
+  public void download_failedTooManyAttempts() throws Exception {
+    AbstractDownloadWorker worker =
+        createWorker(
+            createData(/* reuseExistingFile= */ false),
+            WORKER_MAX_DOWNLOAD_ATTEMPTS,
+            TestModelDownloader.withSuccess(CONTENT_BYTES),
+            NO_OP_HANDLE_FUNC);
+    targetModelFile.delete();
+
+    assertThat(targetModelFile.exists()).isFalse();
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.failure());
+    assertThat(targetModelFile.exists()).isFalse();
+  }
+
+  @Test
+  public void download_errorWhenHandlingDownloadedFile() throws Exception {
+    AbstractDownloadWorker worker =
+        createWorker(
+            createData(/* reuseExistingFile= */ false),
+            /* runAttemptCount= */ 0,
+            TestModelDownloader.withSuccess(""),
+            file -> {
+              throw new RuntimeException();
+            });
+    targetModelFile.delete();
+
+    assertThat(targetModelFile.exists()).isFalse();
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
+    // Downloaded file should be cleaned up if the handling function fails
+    assertThat(targetModelFile.exists()).isFalse();
+  }
+
+  private Data createData(boolean reuseExistingFile) {
+    return AbstractDownloadWorker.createInputDataBuilder(
+            URL, targetModelFile.getAbsolutePath(), reuseExistingFile, WORKER_MAX_DOWNLOAD_ATTEMPTS)
+        .build();
+  }
+
+  /** Builds a TestDownloadWorker that uses {@code downloader} and {@code handleFunc}. */
+  private static AbstractDownloadWorker createWorker(
+      Data data, int runAttemptCount, ModelDownloader downloader, Function<File, Void> handleFunc) {
+    return TestListenableWorkerBuilder.from(
+            ApplicationProvider.getApplicationContext(), TestDownloadWorker.class)
+        .setInputData(data)
+        .setRunAttemptCount(runAttemptCount)
+        .setWorkerFactory(
+            new WorkerFactory() {
+              @Override
+              public ListenableWorker createWorker(
+                  Context appContext, String workerClassName, WorkerParameters workerParameters) {
+                return new TestDownloadWorker(
+                    appContext,
+                    workerParameters,
+                    MoreExecutors.newDirectExecutorService(),
+                    downloader,
+                    handleFunc);
+              }
+            })
+        .build();
+  }
+
+  /** A test AbstractDownloadWorker impl which handles downloaded file with a given Function. */
+  private static class TestDownloadWorker extends AbstractDownloadWorker {
+    private final Function<File, Void> handleFunc;
+
+    TestDownloadWorker(
+        Context context,
+        WorkerParameters workerParameters,
+        ExecutorService bgExecutorService,
+        ModelDownloader modelDownloader,
+        Function<File, Void> handleFunc) {
+      super(context, workerParameters, bgExecutorService, modelDownloader);
+
+      this.handleFunc = handleFunc;
+    }
+
+    @Override
+    Void handleDownloadedFile(File downloadedFile) {
+      return handleFunc.apply(downloadedFile);
+    }
+  }
+
+  /** A test ModelDownloader whose result is fixed by the factory method that creates it. */
+  private static class TestModelDownloader implements ModelDownloader {
+    private final String strWrittenToFile;
+    private final ListenableFuture<Long> futureToReturn;
+
+    public static TestModelDownloader withSuccess(String strWrittenToFile) {
+      return new TestModelDownloader(
+          Futures.immediateFuture((long) strWrittenToFile.getBytes().length), strWrittenToFile);
+    }
+
+    public static TestModelDownloader withFailure(Throwable throwable) {
+      return new TestModelDownloader(Futures.immediateFailedFuture(throwable), null);
+    }
+
+    private TestModelDownloader(ListenableFuture<Long> futureToReturn, String strWrittenToFile) {
+      this.strWrittenToFile = strWrittenToFile;
+      this.futureToReturn = futureToReturn;
+    }
+
+    @Override
+    public ListenableFuture<Long> download(URI uri, File targetFile) {
+      if (strWrittenToFile != null) {
+        // FileWriter creates or truncates targetFile; try-with-resources closes it even on error.
+        try (FileWriter fileWriter = new FileWriter(targetFile)) {
+          fileWriter.write(strWrittenToFile);
+        } catch (Exception e) {
+          throw new RuntimeException("Failed to prepare test downloaded file.", e);
+        }
+      }
+      return futureToReturn;
+    }
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
index d248eb0..1c4f7f8 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -232,7 +232,7 @@
.collect(Collectors.toList()));
assertThat(loggedEvents).hasSize(1);
TextClassifierApiUsageReported loggedEvent = loggedEvents.get(0);
- assertThat(loggedEvent.getLatencyMillis()).isGreaterThan(0);
+ assertThat(loggedEvent.getLatencyMillis()).isGreaterThan(0L);
assertThat(loggedEvent.getApiType()).isEqualTo(expectedApiType);
assertThat(loggedEvent.getResultType()).isEqualTo(expectedResultApiType);
assertThat(loggedEvent.getSessionId()).isEqualTo(SESSION_ID);
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DownloaderTestUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/DownloaderTestUtils.java
new file mode 100644
index 0000000..980dda3
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DownloaderTestUtils.java
@@ -0,0 +1,95 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.content.Context;
+import androidx.work.ListenableWorker;
+import androidx.work.WorkInfo;
+import androidx.work.WorkManager;
+import androidx.work.WorkQuery;
+import androidx.work.WorkerParameters;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import java.util.List;
+
+/** Utils for downloader logic testing. */
+final class DownloaderTestUtils {
+
+ /** One unique queue holds at most one request at one time. Returns null if no WorkInfo found. */
+ public static WorkInfo queryTheOnlyWorkInfo(WorkManager workManager, String queueName)
+ throws Exception {
+ WorkQuery workQuery =
+ WorkQuery.Builder.fromUniqueWorkNames(ImmutableList.of(queueName)).build();
+ List<WorkInfo> workInfos = workManager.getWorkInfos(workQuery).get();
+ if (workInfos.isEmpty()) {
+ return null;
+ } else {
+ return Iterables.getOnlyElement(workInfos);
+ }
+ }
+
+ /**
+ * Completes immediately with the pre-set result. If it's not retry, the result will also include
+ * the input Data as its output Data.
+ */
+ public static final class TestWorker extends ListenableWorker {
+ private static Result expectedResult;
+
+ public TestWorker(Context context, WorkerParameters workerParams) {
+ super(context, workerParams);
+ }
+
+ @Override
+ public ListenableFuture<ListenableWorker.Result> startWork() {
+ if (expectedResult == null) {
+ return Futures.immediateFailedFuture(new Exception("no expected result"));
+ }
+ ListenableWorker.Result result;
+ switch (expectedResult) {
+ case SUCCESS:
+ result = ListenableWorker.Result.success(getInputData());
+ break;
+ case FAILURE:
+ result = ListenableWorker.Result.failure(getInputData());
+ break;
+ case RETRY:
+ result = ListenableWorker.Result.retry();
+ break;
+ default:
+ throw new IllegalStateException("illegal result");
+ }
+ // Reset expected result
+ expectedResult = null;
+ return Futures.immediateFuture(result);
+ }
+
+ /** Sets the expected worker result in a static variable. Will be cleaned up after reading. */
+ public static void setExpectedResult(Result expectedResult) {
+ TestWorker.expectedResult = expectedResult;
+ }
+
+ public enum Result {
+ SUCCESS,
+ FAILURE,
+ RETRY;
+ }
+ }
+
+ private DownloaderTestUtils() {}
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ManifestDownloadWorkerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ManifestDownloadWorkerTest.java
new file mode 100644
index 0000000..38fdf47
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ManifestDownloadWorkerTest.java
@@ -0,0 +1,152 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Context;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.work.ListenableWorker;
+import androidx.work.WorkInfo;
+import androidx.work.WorkManager;
+import androidx.work.WorkerFactory;
+import androidx.work.WorkerParameters;
+import androidx.work.testing.TestDriver;
+import androidx.work.testing.TestListenableWorkerBuilder;
+import androidx.work.testing.WorkManagerTestInitHelper;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import java.io.File;
+import java.nio.file.Files;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public final class ManifestDownloadWorkerTest {
+ private static final String MODEL_URL =
+ "https://www.gstatic.com/android/text_classifier/q/v711/en.fb";
+ private static final long MODEL_SIZE_IN_BYTES = 1L;
+ private static final String MODEL_FINGERPRINT = "hash_fingerprint";
+ private static final String MANIFEST_URL =
+ "https://www.gstatic.com/android/text_classifier/q/v711/en.fb.manifest";
+ private static final String TARGET_MODEL_PATH = "/not_used_fake_path.fb";
+ private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+ private static final String MODEL_LANGUAGE_TAG = "en";
+ private static final String WORK_MANAGER_UNIQUE_WORK_NAME =
+ ModelDownloadManager.getModelUniqueWorkName(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+ private static final int WORKER_MAX_DOWNLOAD_ATTEMPTS = 5;
+ private static final ModelManifest MODEL_MANIFEST_PROTO =
+ ModelManifest.newBuilder()
+ .addModels(
+ ModelManifest.Model.newBuilder()
+ .setUrl(MODEL_URL)
+ .setSizeInBytes(MODEL_SIZE_IN_BYTES)
+ .setFingerprint(MODEL_FINGERPRINT)
+ .build())
+ .build();
+
+ private File manifestFile;
+ private WorkManager workManager;
+ private TestDriver workManagerTestDriver;
+
+ @Before
+ public void setUp() {
+ Context context = ApplicationProvider.getApplicationContext();
+ WorkManagerTestInitHelper.initializeTestWorkManager(context);
+
+ this.manifestFile = new File(context.getCacheDir(), "model.fb.manifest");
+ this.workManager = WorkManager.getInstance(context);
+ this.workManagerTestDriver = WorkManagerTestInitHelper.getTestDriver(context);
+
+ manifestFile.deleteOnExit();
+ }
+
+ @Test
+ public void enqueueSuccessfullyAndCheckData() throws Exception {
+ ManifestDownloadWorker worker = createWorker(MANIFEST_URL, manifestFile.getAbsolutePath());
+
+ // We only want to test the downloaded file handling code, so reuse existing manifest file
+ manifestFile.createNewFile();
+ Files.write(manifestFile.toPath(), MODEL_MANIFEST_PROTO.toByteArray());
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+ assertThat(manifestFile.exists()).isTrue();
+
+ WorkInfo workInfo =
+ DownloaderTestUtils.queryTheOnlyWorkInfo(workManager, WORK_MANAGER_UNIQUE_WORK_NAME);
+ assertThat(workInfo).isNotNull();
+ assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+ assertThat(workInfo.getTags()).contains(MANIFEST_URL);
+
+ // Check input Data with TestWorker
+ DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.SUCCESS);
+ workManagerTestDriver.setAllConstraintsMet(workInfo.getId());
+
+ WorkInfo newWorkInfo =
+ DownloaderTestUtils.queryTheOnlyWorkInfo(workManager, WORK_MANAGER_UNIQUE_WORK_NAME);
+ assertThat(newWorkInfo.getId()).isEqualTo(workInfo.getId());
+ assertThat(newWorkInfo.getState()).isEqualTo(WorkInfo.State.SUCCEEDED);
+ assertThat(newWorkInfo.getOutputData().getString(AbstractDownloadWorker.DATA_URL_KEY))
+ .isEqualTo(MODEL_URL);
+ assertThat(
+ newWorkInfo
+ .getOutputData()
+ .getLong(ModelDownloadWorker.DATA_MODEL_SIZE_IN_BYTES_KEY, /* defaultValue= */ -1))
+ .isEqualTo(MODEL_SIZE_IN_BYTES);
+ assertThat(
+ newWorkInfo.getOutputData().getString(ModelDownloadWorker.DATA_MODEL_FINGERPRINT_KEY))
+ .isEqualTo(MODEL_FINGERPRINT);
+ assertThat(
+ newWorkInfo.getOutputData().getString(ModelDownloadWorker.DATA_TARGET_MODEL_PATH_KEY))
+ .isEqualTo(TARGET_MODEL_PATH);
+ }
+
+ @Test
+ public void invalidManifestFile_invalidFileDeletedAndRetry() throws Exception {
+ ManifestDownloadWorker worker = createWorker(MANIFEST_URL, manifestFile.getAbsolutePath());
+
+ manifestFile.createNewFile();
+ Files.write(manifestFile.toPath(), "random_content".getBytes());
+ assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
+ assertThat(manifestFile.exists()).isFalse();
+ }
+
+ private static ManifestDownloadWorker createWorker(String manifestUrl, String manifestPath) {
+ return TestListenableWorkerBuilder.from(
+ ApplicationProvider.getApplicationContext(), ManifestDownloadWorker.class)
+ .setInputData(
+ ManifestDownloadWorker.createInputData(
+ MODEL_TYPE,
+ MODEL_LANGUAGE_TAG,
+ manifestUrl,
+ manifestPath,
+ TARGET_MODEL_PATH,
+ WORKER_MAX_DOWNLOAD_ATTEMPTS,
+ /* reuseExistingManifestFile= */ true))
+ .setRunAttemptCount(0)
+ .setWorkerFactory(
+ new WorkerFactory() {
+ @Override
+ public ListenableWorker createWorker(
+ Context appContext, String workerClassName, WorkerParameters workerParameters) {
+ return new ManifestDownloadWorker(
+ appContext, workerParameters, DownloaderTestUtils.TestWorker.class);
+ }
+ })
+ .build();
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloadManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloadManagerTest.java
new file mode 100644
index 0000000..8564130
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloadManagerTest.java
@@ -0,0 +1,417 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.content.Context;
+import android.os.LocaleList;
+import android.provider.DeviceConfig.Properties;
+import androidx.annotation.NonNull;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.work.ExistingWorkPolicy;
+import androidx.work.OneTimeWorkRequest;
+import androidx.work.WorkInfo;
+import androidx.work.WorkManager;
+import androidx.work.testing.TestDriver;
+import androidx.work.testing.WorkManagerTestInitHelper;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.util.HashMap;
+import java.util.Locale;
+import javax.annotation.Nullable;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+
+/** Tests for {@link ModelDownloadManager}'s config checking and download scheduling. */
+@RunWith(AndroidJUnit4.class)
+public final class ModelDownloadManagerTest {
+  private static final String URL_PREFIX = ModelDownloadManager.TEXT_CLASSIFIER_URL_PREFIX;
+  private static final String URL_SUFFIX = "abc.xyz";
+  private static final String URL_SUFFIX_2 = "def.xyz";
+  private static final String URL = URL_PREFIX + URL_SUFFIX;
+  private static final String URL_2 = URL_PREFIX + URL_SUFFIX_2;
+  // Parameterized test is not yet supported for instrumentation test
+  @ModelType.ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
+  @ModelType.ModelTypeDef private static final String MODEL_TYPE_2 = ModelType.ACTIONS_SUGGESTIONS;
+  private static final String MODEL_LANGUAGE_TAG = "en";
+  private static final String MODEL_LANGUAGE_TAG_2 = "zh";
+  private static final String MODEL_LANGUAGE_UNIVERSAL_TAG =
+      ModelDownloadManager.UNIVERSAL_MODEL_LANGUAGE_TAG;
+  private static final LocaleList DEFAULT_LOCALE_LIST =
+      new LocaleList(new Locale(MODEL_LANGUAGE_TAG));
+
+  @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+
+  // TODO(licha): Maybe we can just use the real TextClassifierSettings
+  private FakeDeviceConfig fakeDeviceConfig;
+  private WorkManager workManager;
+  private TestDriver workManagerTestDriver;
+  private File downloadTargetFile;
+  private ModelDownloadManager downloadManager;
+
+  @Before
+  public void setUp() {
+    Context context = ApplicationProvider.getApplicationContext();
+    WorkManagerTestInitHelper.initializeTestWorkManager(context);
+
+    this.fakeDeviceConfig = new FakeDeviceConfig();
+    this.workManager = WorkManager.getInstance(context);
+    this.workManagerTestDriver = WorkManagerTestInitHelper.getTestDriver(context);
+    ModelFileManager modelFileManager = new ModelFileManager(context, ImmutableList.of());
+    this.downloadTargetFile = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL_SUFFIX);
+    this.downloadManager =
+        new ModelDownloadManager(
+            workManager,
+            DownloaderTestUtils.TestWorker.class,
+            modelFileManager,
+            new TextClassifierSettings(fakeDeviceConfig),
+            MoreExecutors.newDirectExecutorService());
+    setDefaultLocalesRule.set(DEFAULT_LOCALE_LIST);
+  }
+
+  @After
+  public void tearDown() {
+    recursiveDelete(ApplicationProvider.getApplicationContext().getFilesDir());
+  }
+
+  @Test
+  public void init_checkConfigWhenInit() throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    downloadManager.init();
+
+    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_flagNotSet() throws Exception {
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+
+    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(workInfo).isNull();
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_fileAlreadyExists() throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    try {
+      downloadTargetFile.getParentFile().mkdirs();
+      downloadTargetFile.createNewFile();
+      downloadManager.checkConfigAndScheduleDownloadsForTesting();
+
+      WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+      assertThat(workInfo).isNull();
+    } finally {
+      downloadTargetFile.delete();
+    }
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_doNotRedownloadTheSameModel() throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    // Simulates a previous model download task
+    OneTimeWorkRequest modelDownloadRequest =
+        new OneTimeWorkRequest.Builder(DownloaderTestUtils.TestWorker.class).addTag(URL).build();
+    DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.SUCCESS);
+    workManager
+        .enqueueUniqueWork(
+            ModelDownloadManager.getModelUniqueWorkName(MODEL_TYPE, MODEL_LANGUAGE_TAG),
+            ExistingWorkPolicy.REPLACE,
+            modelDownloadRequest)
+        .getResult()
+        .get();
+
+    // Assert the model download work succeeded
+    WorkInfo succeededModelWorkInfo =
+        DownloaderTestUtils.queryTheOnlyWorkInfo(
+            workManager,
+            ModelDownloadManager.getModelUniqueWorkName(MODEL_TYPE, MODEL_LANGUAGE_TAG));
+    assertThat(succeededModelWorkInfo.getState()).isEqualTo(WorkInfo.State.SUCCEEDED);
+
+    // Trigger the config check
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(workInfo).isNull();
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_requestEnqueuedSuccessfully() throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+
+    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_multipleModelsEnqueued() throws Exception {
+    for (@ModelType.ModelTypeDef String modelType : ModelType.values()) {
+      setUpModelUrlSuffix(modelType, MODEL_LANGUAGE_TAG, modelType + URL_SUFFIX);
+    }
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+
+    for (@ModelType.ModelTypeDef String modelType : ModelType.values()) {
+      WorkInfo workInfo = queryTheOnlyWorkInfo(modelType, MODEL_LANGUAGE_TAG);
+      assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+    }
+  }
+
+  // This test is to make sure we will not schedule a new task if another task exists with the same
+  // url tag, even if it's in a different queue. Currently we schedule both manifest and model
+  // download tasks with the same model url tag. This behavior protects us from unintended task
+  // overriding.
+  @Test
+  public void checkConfigAndScheduleDownloads_urlIsCheckedGlobally() throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+
+    WorkInfo workInfo1 = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(workInfo1.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+
+    // Set the same url to a different model type flag
+    setUpModelUrlSuffix(MODEL_TYPE_2, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+
+    workInfo1 = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(workInfo1.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+    WorkInfo workInfo2 = queryTheOnlyWorkInfo(MODEL_TYPE_2, MODEL_LANGUAGE_TAG);
+    assertThat(workInfo2).isNull();
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_checkMultipleTimes() throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+    WorkInfo oldWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+    WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+
+    // Will not schedule multiple times, still the same WorkInfo
+    assertThat(oldWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+    assertThat(newWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+    assertThat(oldWorkInfo.getId()).isEqualTo(newWorkInfo.getId());
+    assertThat(oldWorkInfo.getTags()).containsExactlyElementsIn(newWorkInfo.getTags());
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_flagUpdatedWhilePrevDownloadPending()
+      throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+    WorkInfo oldWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX_2);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+    WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+
+    // oldWorkInfo will be replaced with the newWorkInfo
+    assertThat(oldWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+    assertThat(newWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+    assertThat(oldWorkInfo.getId()).isNotEqualTo(newWorkInfo.getId());
+    assertThat(oldWorkInfo.getTags()).contains(URL);
+    assertThat(newWorkInfo.getTags()).contains(URL_2);
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_flagUpdatedAfterPrevDownloadDone() throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+    WorkInfo oldWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    // Run scheduled download
+    DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.SUCCESS);
+    workManagerTestDriver.setAllConstraintsMet(oldWorkInfo.getId());
+    try {
+      // The parent dir may be missing; create it as the fileAlreadyExists test does.
+      downloadTargetFile.getParentFile().mkdirs();
+      // Create download file
+      downloadTargetFile.createNewFile();
+      downloadManager.checkConfigAndScheduleDownloadsForTesting();
+      // Update device config
+      setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX_2);
+      downloadManager.checkConfigAndScheduleDownloadsForTesting();
+      WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+
+      // Assert new request can be queued successfully
+      assertThat(newWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+      assertThat(newWorkInfo.getTags()).contains(URL_2);
+      assertThat(oldWorkInfo.getId()).isNotEqualTo(newWorkInfo.getId());
+    } finally {
+      downloadTargetFile.delete();
+    }
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_workerSucceeded() throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+
+    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+
+    DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.SUCCESS);
+    workManagerTestDriver.setAllConstraintsMet(workInfo.getId());
+
+    WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(newWorkInfo.getId()).isEqualTo(workInfo.getId());
+    assertThat(newWorkInfo.getState()).isEqualTo(WorkInfo.State.SUCCEEDED);
+    assertThat(newWorkInfo.getOutputData().getString(AbstractDownloadWorker.DATA_URL_KEY))
+        .isEqualTo(URL);
+    assertThat(
+            newWorkInfo.getOutputData().getString(AbstractDownloadWorker.DATA_DESTINATION_PATH_KEY))
+        .isEqualTo(
+            ModelDownloadManager.getTargetManifestPath(downloadTargetFile.getAbsolutePath()));
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_workerFailed() throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+
+    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+
+    DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.FAILURE);
+    workManagerTestDriver.setAllConstraintsMet(workInfo.getId());
+
+    WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(newWorkInfo.getId()).isEqualTo(workInfo.getId());
+    assertThat(newWorkInfo.getState()).isEqualTo(WorkInfo.State.FAILED);
+    assertThat(newWorkInfo.getOutputData().getString(AbstractDownloadWorker.DATA_URL_KEY))
+        .isEqualTo(URL);
+    assertThat(
+            newWorkInfo.getOutputData().getString(AbstractDownloadWorker.DATA_DESTINATION_PATH_KEY))
+        .isEqualTo(
+            ModelDownloadManager.getTargetManifestPath(downloadTargetFile.getAbsolutePath()));
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_workerRetried() throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG, URL_SUFFIX);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+
+    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+
+    DownloaderTestUtils.TestWorker.setExpectedResult(DownloaderTestUtils.TestWorker.Result.RETRY);
+    workManagerTestDriver.setAllConstraintsMet(workInfo.getId());
+
+    WorkInfo newWorkInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG);
+    assertThat(newWorkInfo.getId()).isEqualTo(workInfo.getId());
+    assertThat(newWorkInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+    assertThat(newWorkInfo.getRunAttemptCount()).isEqualTo(1);
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_chooseTheBestLocaleTag() throws Exception {
+    // System default locale: zh-hant-hk
+    setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("zh-hant-hk")));
+
+    // All configured locale tags
+    setUpModelUrlSuffix(MODEL_TYPE, "zh-hant", URL_SUFFIX); // best match
+    setUpModelUrlSuffix(MODEL_TYPE, "zh", URL_SUFFIX_2); // too general
+    setUpModelUrlSuffix(MODEL_TYPE, "zh-hk", URL_SUFFIX_2); // missing script
+    setUpModelUrlSuffix(MODEL_TYPE, "zh-hans-hk", URL_SUFFIX_2); // incorrect script
+    setUpModelUrlSuffix(MODEL_TYPE, "es-hant-hk", URL_SUFFIX_2); // incorrect language
+
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+
+    // The downloader choose: zh-hant
+    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh-hant").getState())
+        .isEqualTo(WorkInfo.State.ENQUEUED);
+
+    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh")).isNull();
+    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh-hk")).isNull();
+    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "zh-hans-hk")).isNull();
+    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, "es-hant-hk")).isNull();
+  }
+
+  @Test
+  public void checkConfigAndScheduleDownloads_useUniversalModelIfNoMatchedTag() throws Exception {
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_TAG_2, URL_SUFFIX);
+    setUpModelUrlSuffix(MODEL_TYPE, MODEL_LANGUAGE_UNIVERSAL_TAG, URL_SUFFIX_2);
+    downloadManager.checkConfigAndScheduleDownloadsForTesting();
+
+    assertThat(queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_TAG_2)).isNull();
+
+    WorkInfo workInfo = queryTheOnlyWorkInfo(MODEL_TYPE, MODEL_LANGUAGE_UNIVERSAL_TAG);
+    assertThat(workInfo.getState()).isEqualTo(WorkInfo.State.ENQUEUED);
+    assertThat(workInfo.getTags()).contains(URL_2);
+  }
+
+  private void setUpModelUrlSuffix(
+      @ModelType.ModelTypeDef String modelType, String modelLanguageTag, String urlSuffix) {
+    String deviceConfigFlag =
+        String.format(
+            TextClassifierSettings.MANIFEST_URL_SUFFIX_TEMPLATE, modelType, modelLanguageTag);
+    fakeDeviceConfig.setConfig(deviceConfigFlag, urlSuffix);
+  }
+
+  /** Returns the only WorkInfo in the model's manifest download queue, or null if it is empty. */
+  private WorkInfo queryTheOnlyWorkInfo(
+      @ModelType.ModelTypeDef String modelType, String modelLanguageTag) throws Exception {
+    return DownloaderTestUtils.queryTheOnlyWorkInfo(
+        workManager, ModelDownloadManager.getManifestUniqueWorkName(modelType, modelLanguageTag));
+  }
+
+  /** Best-effort deletion of {@code f} and, if it is a directory, everything beneath it. */
+  private static void recursiveDelete(File f) {
+    File[] children = f.listFiles();
+    // listFiles() returns null for non-directories and on I/O errors; both mean "no children".
+    if (children != null) {
+      for (File child : children) {
+        recursiveDelete(child);
+      }
+    }
+    f.delete();
+  }
+
+  /** In-memory DeviceConfig stand-in whose flags are set directly by the tests. */
+  private static class FakeDeviceConfig implements TextClassifierSettings.IDeviceConfig {
+
+    private final HashMap<String, String> configs;
+
+    public FakeDeviceConfig() {
+      this.configs = new HashMap<>();
+    }
+
+    public void setConfig(String key, String value) {
+      configs.put(key, value);
+    }
+
+    // NOTE(review): names is ignored, so every stored flag is returned no matter which
+    // properties the caller asked for.
+    @Override
+    public Properties getProperties(@NonNull String namespace, @NonNull String... names) {
+      Properties.Builder builder = new Properties.Builder(namespace);
+      configs.forEach((key, value) -> builder.setString(key, value));
+      return builder.build();
+    }
+
+    @Override
+    public String getString(
+        @NonNull String namespace, @NonNull String name, @Nullable String defaultValue) {
+      return configs.getOrDefault(name, defaultValue);
+    }
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloadWorkerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloadWorkerTest.java
new file mode 100644
index 0000000..107aa95
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloadWorkerTest.java
@@ -0,0 +1,185 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.testng.Assert.expectThrows;
+
+import android.content.Context;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.work.ListenableWorker;
+import androidx.work.WorkerFactory;
+import androidx.work.WorkerParameters;
+import androidx.work.testing.TestListenableWorkerBuilder;
+import androidx.work.testing.WorkManagerTestInitHelper;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link ModelDownloadWorker}'s verification of a downloaded model file. */
+@RunWith(JUnit4.class)
+public final class ModelDownloadWorkerTest {
+  private static final String MODEL_URL =
+      "http://www.gstatic.com/android/text_classifier/q/v711/en.fb";
+  private static final String MODEL_CONTENT = "content";
+  private static final String MODEL_CONTENT_CORRUPTED = "CONTENT";
+  // MODEL_SIZE_IN_BYTES and MODEL_FINGERPRINT describe MODEL_CONTENT. MODEL_CONTENT_CORRUPTED has
+  // the same length but different bytes, so it should fail the fingerprint check, not the size one.
+  private static final long MODEL_SIZE_IN_BYTES = 7L;
+  private static final String MODEL_FINGERPRINT =
+      "5406ebea1618e9b73a7290c5d716f0b47b4f1fbc5d8c"
+          + "5e78c9010a3e01c18d8594aa942e3536f7e01574245d34647523";
+  private static final int WORKER_MAX_DOWNLOAD_ATTEMPTS = 5;
+
+  private File manifestFile;
+  private File pendingModelFile;
+  private File targetModelFile;
+
+  @Before
+  public void setUp() {
+    Context context = ApplicationProvider.getApplicationContext();
+    WorkManagerTestInitHelper.initializeTestWorkManager(context);
+
+    this.manifestFile = new File(context.getCacheDir(), "model.fb.manifest");
+    this.pendingModelFile = new File(context.getCacheDir(), "model.fb.pending");
+    this.targetModelFile = new File(context.getCacheDir(), "model.fb");
+  }
+
+  @After
+  public void tearDown() {
+    manifestFile.delete();
+    pendingModelFile.delete();
+    targetModelFile.delete();
+  }
+
+  @Test
+  public void passedVerificationAndMoved() throws Exception {
+    ModelDownloadWorker worker = createWorker(manifestFile, pendingModelFile, targetModelFile);
+    manifestFile.createNewFile();
+    writeToFile(pendingModelFile, MODEL_CONTENT);
+
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(targetModelFile.exists()).isTrue();
+    assertThat(pendingModelFile.exists()).isFalse();
+    assertThat(manifestFile.exists()).isFalse();
+  }
+
+  @Test
+  public void passedVerificationAndReplaced() throws Exception {
+    ModelDownloadWorker worker = createWorker(manifestFile, pendingModelFile, targetModelFile);
+    manifestFile.createNewFile();
+    writeToFile(pendingModelFile, MODEL_CONTENT);
+    writeToFile(targetModelFile, MODEL_CONTENT);
+
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.success());
+    assertThat(targetModelFile.exists()).isTrue();
+    assertThat(pendingModelFile.exists()).isFalse();
+    assertThat(manifestFile.exists()).isFalse();
+  }
+
+  @Test
+  public void failedVerificationAndRetry() throws Exception {
+    ModelDownloadWorker worker = createWorker(manifestFile, pendingModelFile, targetModelFile);
+    manifestFile.createNewFile();
+    writeToFile(pendingModelFile, /* content= */ "");
+
+    assertThat(worker.startWork().get()).isEqualTo(ListenableWorker.Result.retry());
+    assertThat(targetModelFile.exists()).isFalse();
+    assertThat(pendingModelFile.exists()).isFalse();
+    assertThat(manifestFile.exists()).isTrue();
+  }
+
+  @Test
+  public void validateModel_validationPassed() throws Exception {
+    ModelDownloadWorker worker = createWorker(manifestFile, pendingModelFile, targetModelFile);
+    writeToFile(pendingModelFile, MODEL_CONTENT);
+    // Passes as long as handleDownloadedFile() does not throw.
+    worker.handleDownloadedFile(pendingModelFile);
+  }
+
+  @Test
+  public void validateModel_fileDoesNotExist() throws Exception {
+    ModelDownloadWorker worker = createWorker(manifestFile, pendingModelFile, targetModelFile);
+    pendingModelFile.delete();
+    IllegalStateException e =
+        expectThrows(
+            IllegalStateException.class, () -> worker.handleDownloadedFile(pendingModelFile));
+    assertThat(e).hasCauseThat().hasMessageThat().contains("does not exist");
+  }
+
+  @Test
+  public void validateModel_emptyFile() throws Exception {
+    ModelDownloadWorker worker = createWorker(manifestFile, pendingModelFile, targetModelFile);
+    writeToFile(pendingModelFile, /* content= */ "");
+    IllegalStateException e =
+        expectThrows(
+            IllegalStateException.class, () -> worker.handleDownloadedFile(pendingModelFile));
+    assertThat(e).hasCauseThat().hasMessageThat().contains("size does not match");
+  }
+
+  @Test
+  public void validateModel_corruptedContent() throws Exception {
+    ModelDownloadWorker worker = createWorker(manifestFile, pendingModelFile, targetModelFile);
+    writeToFile(pendingModelFile, MODEL_CONTENT_CORRUPTED);
+    IllegalStateException e =
+        expectThrows(
+            IllegalStateException.class, () -> worker.handleDownloadedFile(pendingModelFile));
+    assertThat(e).hasCauseThat().hasMessageThat().contains("fingerprint does not match");
+  }
+
+  /**
+   * Builds a worker with {@code reuseExistingModelFile=true}; the tests pre-populate
+   * {@code pendingModelFile}, presumably so it is verified instead of downloaded from MODEL_URL.
+   */
+  private static ModelDownloadWorker createWorker(
+      File manifestFile, File pendingModelFile, File targetModelFile) {
+    return TestListenableWorkerBuilder.from(
+            ApplicationProvider.getApplicationContext(), ModelDownloadWorker.class)
+        .setInputData(
+            ModelDownloadWorker.createInputData(
+                MODEL_URL,
+                MODEL_SIZE_IN_BYTES,
+                MODEL_FINGERPRINT,
+                manifestFile.getAbsolutePath(),
+                pendingModelFile.getAbsolutePath(),
+                targetModelFile.getAbsolutePath(),
+                WORKER_MAX_DOWNLOAD_ATTEMPTS,
+                /* reuseExistingModelFile= */ true))
+        .setRunAttemptCount(0)
+        .setWorkerFactory(
+            new WorkerFactory() {
+              @Override
+              public ListenableWorker createWorker(
+                  Context appContext, String workerClassName, WorkerParameters workerParameters) {
+                return new ModelDownloadWorker(appContext, workerParameters);
+              }
+            })
+        .build();
+  }
+
+  /** Writes {@code content} to {@code file}, creating it or truncating any previous content. */
+  private static void writeToFile(File file, String content) throws IOException {
+    try (FileWriter fileWriter = new FileWriter(file)) {
+      fileWriter.write(content, /* off= */ 0, content.length());
+    }
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloaderImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloaderImplTest.java
new file mode 100644
index 0000000..806172d
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloaderImplTest.java
@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.testng.Assert.expectThrows;
+
+import android.content.Context;
+import androidx.test.core.app.ApplicationProvider;
+import com.android.textclassifier.TestModelDownloaderService.DownloadResult;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
+import java.io.File;
+import java.net.URI;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Tests for {@link ModelDownloaderImpl}, bound to the fake {@link TestModelDownloaderService}.
+ *
+ * <p>The fake keeps its state in static fields, so {@link #tearDown()} calls {@link
+ * TestModelDownloaderService#reset()} after every test.
+ */
+@RunWith(JUnit4.class)
+public final class ModelDownloaderImplTest {
+  private static final URI TEST_URI = URI.create("test_uri");
+
+  private ModelDownloaderImpl modelDownloaderImpl;
+  private File targetFile;
+
+  @Before
+  public void setUp() {
+    Context context = ApplicationProvider.getApplicationContext();
+    modelDownloaderImpl =
+        new ModelDownloaderImpl(
+            context, MoreExecutors.newDirectExecutorService(), TestModelDownloaderService.class);
+    targetFile = new File(context.getCacheDir(), "targetFile.fb");
+  }
+
+  @After
+  public void tearDown() {
+    TestModelDownloaderService.reset();
+  }
+
+  @Test
+  public void download_failToBind() throws Exception {
+    assertNeverBound();
+
+    TestModelDownloaderService.setBindSucceed(false);
+    ListenableFuture<Long> bytesWrittenFuture = modelDownloaderImpl.download(TEST_URI, targetFile);
+
+    expectThrows(Throwable.class, bytesWrittenFuture::get);
+    assertNeverBound();
+  }
+
+  @Test
+  public void download_succeed() throws Exception {
+    assertNeverBound();
+
+    TestModelDownloaderService.setBindSucceed(true);
+    TestModelDownloaderService.setDownloadResult(DownloadResult.SUCCEEDED);
+    ListenableFuture<Long> bytesWrittenFuture = modelDownloaderImpl.download(TEST_URI, targetFile);
+
+    assertThat(bytesWrittenFuture.get()).isEqualTo(TestModelDownloaderService.BYTES_WRITTEN);
+    assertUnbindsAfterBeingBound();
+  }
+
+  @Test
+  public void download_fail() throws Exception {
+    assertNeverBound();
+
+    TestModelDownloaderService.setBindSucceed(true);
+    TestModelDownloaderService.setDownloadResult(DownloadResult.FAILED);
+    ListenableFuture<Long> bytesWrittenFuture = modelDownloaderImpl.download(TEST_URI, targetFile);
+
+    Throwable thrown = expectThrows(Throwable.class, bytesWrittenFuture::get);
+    assertThat(thrown).hasMessageThat().contains(TestModelDownloaderService.ERROR_MSG);
+    assertUnbindsAfterBeingBound();
+  }
+
+  @Test
+  public void download_cancelAndUnbind() throws Exception {
+    assertNeverBound();
+
+    TestModelDownloaderService.setBindSucceed(true);
+    TestModelDownloaderService.setDownloadResult(DownloadResult.RUNNING_FOREVER);
+    ListenableFuture<Long> bytesWrittenFuture = modelDownloaderImpl.download(TEST_URI, targetFile);
+    bytesWrittenFuture.cancel(true);
+
+    expectThrows(Throwable.class, bytesWrittenFuture::get);
+    assertUnbindsAfterBeingBound();
+  }
+
+  /** Asserts that the fake service is not bound now and has never been bound. */
+  private static void assertNeverBound() {
+    assertThat(TestModelDownloaderService.hasEverBeenBound()).isFalse();
+    assertThat(TestModelDownloaderService.isBound()).isFalse();
+  }
+
+  /**
+   * Asserts that {@code onUnbind} of the fake service runs within one second, and that the service
+   * had been bound before but is no longer bound.
+   */
+  private static void assertUnbindsAfterBeingBound() throws InterruptedException {
+    assertThat(TestModelDownloaderService.getOnUnbindInvokedLatch().await(1L, SECONDS)).isTrue();
+    assertThat(TestModelDownloaderService.isBound()).isFalse();
+    assertThat(TestModelDownloaderService.hasEverBeenBound()).isTrue();
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloaderServiceImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloaderServiceImplTest.java
new file mode 100644
index 0000000..d50dc78
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelDownloaderServiceImplTest.java
@@ -0,0 +1,190 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
+import static org.testng.Assert.expectThrows;
+
+import androidx.test.core.app.ApplicationProvider;
+import com.google.android.downloader.DownloadConstraints;
+import com.google.android.downloader.DownloadRequest;
+import com.google.android.downloader.DownloadResult;
+import com.google.android.downloader.Downloader;
+import com.google.android.downloader.SimpleFileDownloadDestination;
+import com.google.common.util.concurrent.FluentFuture;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.common.util.concurrent.SettableFuture;
+import java.io.File;
+import java.net.URI;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+/** Tests for {@link ModelDownloaderServiceImpl}, with the {@link Downloader} mocked out. */
+@RunWith(JUnit4.class)
+public final class ModelDownloaderServiceImplTest {
+  private static final long BYTES_WRITTEN = 1L;
+  private static final String DOWNLOAD_URI =
+      "https://www.gstatic.com/android/text_classifier/r/v999/en.fb";
+
+  @Mock private Downloader downloader;
+  private File targetModelFile;
+  private File targetMetadataFile;
+  private ModelDownloaderServiceImpl modelDownloaderServiceImpl;
+  private TestSuccessCallbackImpl successCallback;
+  private TestFailureCallbackImpl failureCallback;
+
+  @Before
+  public void setUp() {
+    MockitoAnnotations.initMocks(this);
+
+    this.targetModelFile =
+        new File(ApplicationProvider.getApplicationContext().getCacheDir(), "model.fb");
+    this.targetMetadataFile = ModelDownloaderServiceImpl.getMetadataFile(targetModelFile);
+    this.modelDownloaderServiceImpl =
+        new ModelDownloaderServiceImpl(MoreExecutors.newDirectExecutorService(), downloader);
+    this.successCallback = new TestSuccessCallbackImpl();
+    this.failureCallback = new TestFailureCallbackImpl();
+
+    when(downloader.newRequestBuilder(any(), any()))
+        .thenReturn(
+            DownloadRequest.newBuilder()
+                .setUri(URI.create(DOWNLOAD_URI))
+                .setDownloadConstraints(DownloadConstraints.NONE)
+                .setDestination(
+                    new SimpleFileDownloadDestination(targetModelFile, targetMetadataFile)));
+  }
+
+  @After
+  public void tearDown() {
+    // Delete the files explicitly: File.deleteOnExit() is not a reliable cleanup on Android,
+    // because an app's VM normally does not terminate, so they would stay in the cache dir.
+    if (targetModelFile != null) {
+      targetModelFile.delete();
+    }
+    if (targetMetadataFile != null) {
+      targetMetadataFile.delete();
+    }
+  }
+
+  @Test
+  public void download_succeeded() throws Exception {
+    targetModelFile.createNewFile();
+    targetMetadataFile.createNewFile();
+    when(downloader.execute(any()))
+        .thenReturn(
+            FluentFuture.from(Futures.immediateFuture(DownloadResult.create(BYTES_WRITTEN))));
+    modelDownloaderServiceImpl.download(
+        DOWNLOAD_URI, targetModelFile.getAbsolutePath(), successCallback);
+
+    assertThat(successCallback.getBytesWrittenFuture().get()).isEqualTo(BYTES_WRITTEN);
+    assertThat(targetModelFile.exists()).isTrue();
+    assertThat(targetMetadataFile.exists()).isFalse();
+  }
+
+  @Test
+  public void download_failed() throws Exception {
+    targetModelFile.createNewFile();
+    targetMetadataFile.createNewFile();
+    when(downloader.execute(any()))
+        .thenReturn(FluentFuture.from(Futures.immediateFailedFuture(new Exception("err_msg"))));
+    modelDownloaderServiceImpl.download(
+        DOWNLOAD_URI, targetModelFile.getAbsolutePath(), successCallback);
+
+    Throwable t =
+        expectThrows(Throwable.class, () -> successCallback.getBytesWrittenFuture().get());
+    assertThat(t).hasMessageThat().contains("err_msg");
+    assertThat(targetModelFile.exists()).isFalse();
+    assertThat(targetMetadataFile.exists()).isFalse();
+  }
+
+  @Test
+  public void download_succeeded_callbackFailed() throws Exception {
+    targetModelFile.createNewFile();
+    targetMetadataFile.createNewFile();
+    when(downloader.execute(any()))
+        .thenReturn(
+            FluentFuture.from(Futures.immediateFuture(DownloadResult.create(BYTES_WRITTEN))));
+    modelDownloaderServiceImpl.download(
+        DOWNLOAD_URI, targetModelFile.getAbsolutePath(), failureCallback);
+
+    assertThat(failureCallback.onSuccessCalled).isTrue();
+    assertThat(targetModelFile.exists()).isTrue();
+    assertThat(targetMetadataFile.exists()).isFalse();
+  }
+
+  @Test
+  public void download_failed_callbackFailed() throws Exception {
+    targetModelFile.createNewFile();
+    targetMetadataFile.createNewFile();
+    when(downloader.execute(any()))
+        .thenReturn(FluentFuture.from(Futures.immediateFailedFuture(new Exception("err_msg"))));
+    modelDownloaderServiceImpl.download(
+        DOWNLOAD_URI, targetModelFile.getAbsolutePath(), failureCallback);
+
+    assertThat(failureCallback.onFailureCalled).isTrue();
+    assertThat(targetModelFile.exists()).isFalse();
+    assertThat(targetMetadataFile.exists()).isFalse();
+  }
+
+  // NOTICE: Mocking this AIDL interface was problematic, so hand-written fakes are used instead.
+
+  /** Fake callback that completes {@link #getBytesWrittenFuture()} with the reported result. */
+  private static final class TestSuccessCallbackImpl extends IModelDownloaderCallback.Stub {
+    private final SettableFuture<Long> bytesWrittenFuture = SettableFuture.create();
+
+    public ListenableFuture<Long> getBytesWrittenFuture() {
+      return bytesWrittenFuture;
+    }
+
+    @Override
+    public void onSuccess(long bytesWritten) {
+      bytesWrittenFuture.set(bytesWritten);
+    }
+
+    @Override
+    public void onFailure(String error) {
+      bytesWrittenFuture.setException(new RuntimeException(error));
+    }
+  }
+
+  /** Fake callback that records which method was called and then throws from it. */
+  private static final class TestFailureCallbackImpl extends IModelDownloaderCallback.Stub {
+    public boolean onSuccessCalled = false;
+    public boolean onFailureCalled = false;
+
+    @Override
+    public void onSuccess(long bytesWritten) {
+      onSuccessCalled = true;
+      throw new RuntimeException();
+    }
+
+    @Override
+    public void onFailure(String error) {
+      onFailureCalled = true;
+      throw new RuntimeException();
+    }
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestModelDownloaderService.java b/java/tests/instrumentation/src/com/android/textclassifier/TestModelDownloaderService.java
new file mode 100644
index 0000000..ddef5c1
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestModelDownloaderService.java
@@ -0,0 +1,127 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import android.app.Service;
+import android.content.Intent;
+import android.os.IBinder;
+import java.util.concurrent.CountDownLatch;
+
+/**
+ * Test Service of IModelDownloaderService.
+ *
+ * <p>Tests configure the behavior through the static setters and inspect the binding state through
+ * the static getters, because the service instance itself is created by the framework.
+ */
+public final class TestModelDownloaderService extends Service {
+  public static final String GOOD_URI = "good_uri";
+  public static final String BAD_URI = "bad_uri";
+  public static final long BYTES_WRITTEN = 1L;
+  public static final String ERROR_MSG = "not good uri";
+
+  /** What the fake {@code download} call does with its callback. */
+  public enum DownloadResult {
+    SUCCEEDED,
+    FAILED,
+    RUNNING_FOREVER,
+    DO_NOTHING
+  }
+
+  // The state is written by test code and read from service callbacks, which may run on other
+  // threads, so the fields are volatile. This only guarantees visibility: the fake still assumes
+  // one test at a time, which is fine for test purposes.
+  private static volatile boolean boundBefore = false;
+  private static volatile boolean boundNow = false;
+  private static volatile CountDownLatch onUnbindInvokedLatch = new CountDownLatch(1);
+
+  private static volatile boolean bindSucceed = false;
+  private static volatile DownloadResult downloadResult = DownloadResult.SUCCEEDED;
+
+  public static boolean hasEverBeenBound() {
+    return boundBefore;
+  }
+
+  public static boolean isBound() {
+    return boundNow;
+  }
+
+  public static CountDownLatch getOnUnbindInvokedLatch() {
+    return onUnbindInvokedLatch;
+  }
+
+  public static void setBindSucceed(boolean bindSucceed) {
+    TestModelDownloaderService.bindSucceed = bindSucceed;
+  }
+
+  public static void setDownloadResult(DownloadResult result) {
+    TestModelDownloaderService.downloadResult = result;
+  }
+
+  /** Restores all static state, including the download result, to its initial values. */
+  public static void reset() {
+    boundBefore = false;
+    boundNow = false;
+    onUnbindInvokedLatch = new CountDownLatch(1);
+    bindSucceed = false;
+    downloadResult = DownloadResult.SUCCEEDED;
+  }
+
+  @Override
+  public IBinder onBind(Intent intent) {
+    if (bindSucceed) {
+      boundBefore = true;
+      boundNow = true;
+      return new TestModelDownloaderServiceImpl();
+    } else {
+      return null;
+    }
+  }
+
+  @Override
+  public boolean onUnbind(Intent intent) {
+    boundNow = false;
+    onUnbindInvokedLatch.countDown();
+    return false;
+  }
+
+  private static final class TestModelDownloaderServiceImpl extends IModelDownloaderService.Stub {
+    @Override
+    public void download(String uri, String unused, IModelDownloaderCallback callback) {
+      try {
+        switch (downloadResult) {
+          case SUCCEEDED:
+            callback.onSuccess(BYTES_WRITTEN);
+            break;
+          case FAILED:
+            callback.onFailure(ERROR_MSG);
+            break;
+          case RUNNING_FOREVER:
+            // Never report a result. Block instead of busy-spinning so the abandoned call does
+            // not burn a CPU core for the rest of the test run.
+            new CountDownLatch(1).await();
+            break;
+          case DO_NOTHING:
+            // Do nothing
+        }
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+      } catch (Throwable t) {
+        // The test would timeout if failing to get the callback result
+      }
+    }
+  }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
index 6d6887a..27ea7f0 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
@@ -115,7 +115,7 @@
assertThat(links).hasSize(1);
assertThat(links.get(0).getEntityCount()).isGreaterThan(0);
assertThat(links.get(0).getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
- assertThat(links.get(0).getConfidenceScore(TextClassifier.TYPE_URL)).isGreaterThan(0);
+ assertThat(links.get(0).getConfidenceScore(TextClassifier.TYPE_URL)).isGreaterThan(0f);
}
@Test
@@ -127,7 +127,7 @@
assertThat(textLanguage.getLocaleHypothesisCount()).isGreaterThan(0);
assertThat(textLanguage.getLocale(0).getLanguage()).isEqualTo("ja");
- assertThat(textLanguage.getConfidenceScore(ULocale.JAPANESE)).isGreaterThan(0);
+ assertThat(textLanguage.getConfidenceScore(ULocale.JAPANESE)).isGreaterThan(0f);
}
@Test
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index 06ec640..f28732d 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -21,7 +21,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
-import static org.testng.Assert.assertThrows;
+import static org.testng.Assert.expectThrows;
import android.app.RemoteAction;
import android.content.Context;
@@ -333,7 +333,7 @@
char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength() + 1];
Arrays.fill(manySpaces, ' ');
TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
- assertThrows(IllegalArgumentException.class, () -> classifier.generateLinks(request));
+ expectThrows(IllegalArgumentException.class, () -> classifier.generateLinks(request));
}
@Test
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
index 9a754a3..b629efd 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
@@ -139,7 +139,8 @@
"url_suffix_lang_id_universal", "suffix:lang_id"),
settings ->
assertThat(settings.getLanguageTagsForManifestURLSuffix(ModelType.LANG_ID))
- .containsExactlyElementsIn(ImmutableList.of("universal")));
+ .containsExactlyElementsIn(
+ ImmutableList.of(ModelDownloadManager.UNIVERSAL_MODEL_LANGUAGE_TAG)));
assertSettings(
ImmutableMap.of(
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
index dfc09a7..fdc454d 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/intent/LabeledIntentTest.java
@@ -17,7 +17,7 @@
package com.android.textclassifier.common.intent;
import static com.google.common.truth.Truth.assertThat;
-import static org.testng.Assert.assertThrows;
+import static org.testng.Assert.expectThrows;
import android.content.ComponentName;
import android.content.Context;
@@ -119,7 +119,7 @@
@Test
public void resolve_missingTitle() {
- assertThrows(
+ expectThrows(
IllegalArgumentException.class,
() -> new LabeledIntent(null, null, DESCRIPTION, null, INTENT, REQUEST_CODE));
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLoggerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLoggerTest.java
index e3e74b5..b9b7a95 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLoggerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/common/statsd/TextClassifierApiUsageLoggerTest.java
@@ -90,7 +90,7 @@
TextClassifierApiUsageReported event = loggedEvents.get(0);
assertThat(event.getApiType()).isEqualTo(ApiType.SUGGEST_SELECTION);
assertThat(event.getResultType()).isEqualTo(ResultType.SUCCESS);
- assertThat(event.getLatencyMillis()).isGreaterThan(0);
+ assertThat(event.getLatencyMillis()).isGreaterThan(0L);
assertThat(event.getSessionId()).isEqualTo(SESSION_ID);
}
@@ -119,7 +119,7 @@
TextClassifierApiUsageReported event = loggedEvents.get(0);
assertThat(event.getApiType()).isEqualTo(ApiType.CLASSIFY_TEXT);
assertThat(event.getResultType()).isEqualTo(ResultType.FAIL);
- assertThat(event.getLatencyMillis()).isGreaterThan(0);
+ assertThat(event.getLatencyMillis()).isGreaterThan(0L);
assertThat(event.getSessionId()).isEqualTo(SESSION_ID);
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java b/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
deleted file mode 100644
index 38b53d4..0000000
--- a/java/tests/instrumentation/src/com/android/textclassifier/subjects/EntitySubject.java
+++ /dev/null
@@ -1,51 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package com.android.textclassifier.subjects;
-
-import static com.google.common.truth.Truth.assertAbout;
-
-import com.android.textclassifier.Entity;
-import com.google.common.truth.FailureMetadata;
-import com.google.common.truth.Subject;
-import javax.annotation.Nullable;
-
-/** Test helper for checking {@link com.android.textclassifier.Entity} results. */
-public final class EntitySubject extends Subject {
-
- private static final float TOLERANCE = 0.0001f;
-
- private final Entity entity;
-
- public static EntitySubject assertThat(@Nullable Entity entity) {
- return assertAbout(EntitySubject::new).that(entity);
- }
-
- private EntitySubject(FailureMetadata failureMetadata, @Nullable Entity entity) {
- super(failureMetadata, entity);
- this.entity = entity;
- }
-
- public void isMatchWithinTolerance(@Nullable Entity entity) {
- if (!entity.getEntityType().equals(this.entity.getEntityType())) {
- failWithActual("expected to have type", entity.getEntityType());
- }
- check("expected to have confidence score")
- .that(entity.getScore())
- .isWithin(TOLERANCE)
- .of(this.entity.getScore());
- }
-}