Export libtextclassifier
ExtServices APK now preloads universal model.
We mmap the model file straight from the APK, so we can't
compress those files when packing them to the APK.
Added "-0 model" as the appt flag to do so.
Bug: 169395238
Test: atest -p external/libtextclassifier
Change-Id: I3d59ec717a14f71be653159ecc7dd3e87bc9a80e
diff --git a/java/Android.bp b/java/Android.bp
index 893b423..981d7c9 100644
--- a/java/Android.bp
+++ b/java/Android.bp
@@ -35,6 +35,9 @@
sdk_version: "system_current",
min_sdk_version: "30",
manifest: "AndroidManifest.xml",
+ aaptflags: [
+ "-0 .model",
+ ],
}
// Similar to TextClassifierServiceLib, but without the AndroidManifest.
@@ -52,6 +55,10 @@
],
sdk_version: "system_current",
min_sdk_version: "30",
+ aaptflags: [
+ "-0 .model",
+ ],
+
}
java_library {
diff --git a/java/assets/textclassifier/actions_suggestions.universal.model b/java/assets/textclassifier/actions_suggestions.universal.model
new file mode 100755
index 0000000..f74fed4
--- /dev/null
+++ b/java/assets/textclassifier/actions_suggestions.universal.model
Binary files differ
diff --git a/java/assets/textclassifier/annotator.universal.model b/java/assets/textclassifier/annotator.universal.model
new file mode 100755
index 0000000..09f1e0b
--- /dev/null
+++ b/java/assets/textclassifier/annotator.universal.model
Binary files differ
diff --git a/java/assets/textclassifier/lang_id.model b/java/assets/textclassifier/lang_id.model
new file mode 100644
index 0000000..e94dada
--- /dev/null
+++ b/java/assets/textclassifier/lang_id.model
Binary files differ
diff --git a/java/src/com/android/textclassifier/DefaultTextClassifierService.java b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
index 7de1acc..8c33ffe 100644
--- a/java/src/com/android/textclassifier/DefaultTextClassifierService.java
+++ b/java/src/com/android/textclassifier/DefaultTextClassifierService.java
@@ -16,7 +16,10 @@
package com.android.textclassifier;
+import android.content.BroadcastReceiver;
import android.content.Context;
+import android.content.Intent;
+import android.content.IntentFilter;
import android.os.CancellationSignal;
import android.service.textclassifier.TextClassifierService;
import android.view.textclassifier.ConversationActions;
@@ -55,6 +58,7 @@
private TextClassifierImpl textClassifier;
private TextClassifierSettings settings;
private ModelFileManager modelFileManager;
+ private BroadcastReceiver localeChangedReceiver;
public DefaultTextClassifierService() {
this.injector = new InjectorImpl(this);
@@ -76,14 +80,20 @@
normPriorityExecutor = injector.createNormPriorityExecutor();
lowPriorityExecutor = injector.createLowPriorityExecutor();
textClassifier = injector.createTextClassifierImpl(settings, modelFileManager);
+ localeChangedReceiver = new LocaleChangedReceiver(modelFileManager);
textClassifierApiUsageLogger =
injector.createTextClassifierApiUsageLogger(settings, lowPriorityExecutor);
+
+ injector
+ .getContext()
+ .registerReceiver(localeChangedReceiver, new IntentFilter(Intent.ACTION_LOCALE_CHANGED));
}
@Override
public void onDestroy() {
super.onDestroy();
+ injector.getContext().unregisterReceiver(localeChangedReceiver);
}
@Override
@@ -221,6 +231,22 @@
MoreExecutors.directExecutor());
}
+ /**
+ * Receiver listening to locale change event. Ask ModelFileManager to do clean-up upon receiving.
+ */
+ static class LocaleChangedReceiver extends BroadcastReceiver {
+ private final ModelFileManager modelFileManager;
+
+ LocaleChangedReceiver(ModelFileManager modelFileManager) {
+ this.modelFileManager = modelFileManager;
+ }
+
+ @Override
+ public void onReceive(Context context, Intent intent) {
+ modelFileManager.deleteUnusedModelFiles();
+ }
+ }
+
// Do not call any of these methods, except the constructor, before Service.onCreate is called.
private static class InjectorImpl implements Injector {
// Do not access the context object before Service.onCreate is invoked.
@@ -231,6 +257,11 @@
}
@Override
+ public Context getContext() {
+ return context;
+ }
+
+ @Override
public ModelFileManager createModelFileManager(TextClassifierSettings settings) {
return new ModelFileManager(context, settings);
}
@@ -280,6 +311,8 @@
* class testable.
*/
interface Injector {
+ Context getContext();
+
ModelFileManager createModelFileManager(TextClassifierSettings settings);
TextClassifierSettings createTextClassifierSettings();
diff --git a/java/src/com/android/textclassifier/ModelFileManager.java b/java/src/com/android/textclassifier/ModelFileManager.java
index 39a70bc..9bc31fb 100644
--- a/java/src/com/android/textclassifier/ModelFileManager.java
+++ b/java/src/com/android/textclassifier/ModelFileManager.java
@@ -17,13 +17,15 @@
package com.android.textclassifier;
import android.content.Context;
+import android.content.res.AssetFileDescriptor;
+import android.content.res.AssetManager;
import android.os.LocaleList;
import android.os.ParcelFileDescriptor;
-import android.text.TextUtils;
+import android.util.ArraySet;
import androidx.annotation.GuardedBy;
import androidx.annotation.StringDef;
-import com.android.textclassifier.ModelFileManager.ModelFile;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import androidx.collection.ArrayMap;
+import com.android.textclassifier.ModelFileManager.ModelType.ModelTypeDef;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
import com.android.textclassifier.utils.IndentingPrintWriter;
@@ -31,64 +33,104 @@
import com.google.android.textclassifier.AnnotatorModel;
import com.google.android.textclassifier.LangIdModel;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Function;
import com.google.common.base.Optional;
import com.google.common.base.Preconditions;
-import com.google.common.base.Splitter;
+import com.google.common.base.Supplier;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
import java.io.File;
-import java.io.FileNotFoundException;
import java.io.IOException;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
+import java.util.Map;
import java.util.Objects;
-import java.util.function.Function;
-import java.util.function.Supplier;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
-// TODO(licha): Support garbage collection to delete unused model files
+// TODO(licha): Consider making this a singleton class
+// TODO(licha): Check whether this is thread-safe
/**
* Manages all model files in storage. {@link TextClassifierImpl} depends on this class to get the
* model files to load.
*/
final class ModelFileManager {
+
private static final String TAG = "ModelFileManager";
+
private static final String DOWNLOAD_SUB_DIR_NAME = "textclassifier/downloads/models/";
+ private static final File CONFIG_UPDATER_DIR = new File("/data/misc/textclassifier/");
+ private static final String ASSETS_DIR = "textclassifier";
- private final File downloadModelDir;
- private final ImmutableMap<String, Supplier<ImmutableList<ModelFile>>> modelFileSuppliers;
+ private final List<ModelFileLister> modelFileListers;
+ private final File modelDownloaderDir;
- /** Create a ModelFileManager based on hardcoded model file locations. */
public ModelFileManager(Context context, TextClassifierSettings settings) {
Preconditions.checkNotNull(context);
Preconditions.checkNotNull(settings);
- this.downloadModelDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
- if (!downloadModelDir.exists()) {
- downloadModelDir.mkdirs();
- }
- ImmutableMap.Builder<String, Supplier<ImmutableList<ModelFile>>> suppliersBuilder =
- ImmutableMap.builder();
- for (String modelType : ModelType.values()) {
- suppliersBuilder.put(
- modelType, new ModelFileSupplierImpl(settings, modelType, downloadModelDir));
- }
- this.modelFileSuppliers = suppliersBuilder.build();
+ AssetManager assetManager = context.getAssets();
+ this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+ modelFileListers =
+ ImmutableList.of(
+ // Annotator models.
+ new RegularFilePatternMatchLister(
+ ModelType.ANNOTATOR,
+ this.modelDownloaderDir,
+ "annotator\\.(.*)\\.model",
+ settings::isModelDownloadManagerEnabled),
+ new RegularFileFullMatchLister(
+ ModelType.ANNOTATOR,
+ new File(CONFIG_UPDATER_DIR, "textclassifier.model"),
+ /* isEnabled= */ () -> true),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.ANNOTATOR,
+ ASSETS_DIR,
+ "annotator\\.(.*)\\.model",
+ /* isEnabled= */ () -> true),
+ // Actions models.
+ new RegularFilePatternMatchLister(
+ ModelType.ACTIONS_SUGGESTIONS,
+ this.modelDownloaderDir,
+ "actions_suggestions\\.(.*)\\.model",
+ settings::isModelDownloadManagerEnabled),
+ new RegularFileFullMatchLister(
+ ModelType.ACTIONS_SUGGESTIONS,
+ new File(CONFIG_UPDATER_DIR, "actions_suggestions.model"),
+ /* isEnabled= */ () -> true),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.ACTIONS_SUGGESTIONS,
+ ASSETS_DIR,
+ "actions_suggestions\\.(.*)\\.model",
+ /* isEnabled= */ () -> true),
+ // LangID models.
+ new RegularFilePatternMatchLister(
+ ModelType.LANG_ID,
+ this.modelDownloaderDir,
+ "lang_id\\.(.*)\\.model",
+ settings::isModelDownloadManagerEnabled),
+ new RegularFileFullMatchLister(
+ ModelType.LANG_ID,
+ new File(CONFIG_UPDATER_DIR, "lang_id.model"),
+ /* isEnabled= */ () -> true),
+ new AssetFilePatternMatchLister(
+ assetManager,
+ ModelType.LANG_ID,
+ ASSETS_DIR,
+ "lang_id.model",
+ /* isEnabled= */ () -> true));
}
@VisibleForTesting
- ModelFileManager(
- File downloadModelDir,
- ImmutableMap<String, Supplier<ImmutableList<ModelFile>>> modelFileSuppliers) {
- this.downloadModelDir = Preconditions.checkNotNull(downloadModelDir);
- this.modelFileSuppliers = Preconditions.checkNotNull(modelFileSuppliers);
+ ModelFileManager(Context context, List<ModelFileLister> modelFileListers) {
+ this.modelDownloaderDir = new File(context.getFilesDir(), DOWNLOAD_SUB_DIR_NAME);
+ this.modelFileListers = ImmutableList.copyOf(modelFileListers);
}
/**
@@ -96,27 +138,203 @@
*
* @param modelType which type of model files to look for
*/
- public ImmutableList<ModelFile> listModelFiles(@ModelType.ModelTypeDef String modelType) {
- if (modelFileSuppliers.containsKey(modelType)) {
- return modelFileSuppliers.get(modelType).get();
+ public ImmutableList<ModelFile> listModelFiles(@ModelTypeDef String modelType) {
+ Preconditions.checkNotNull(modelType);
+
+ ImmutableList.Builder<ModelFile> modelFiles = new ImmutableList.Builder<>();
+ for (ModelFileLister modelFileLister : modelFileListers) {
+ modelFiles.addAll(modelFileLister.list(modelType));
}
- return ImmutableList.of();
+ return modelFiles.build();
+ }
+
+ /** Lists model files. */
+ public interface ModelFileLister {
+ List<ModelFile> list(@ModelTypeDef String modelType);
+ }
+
+ /** Lists model files by performing full match on file path. */
+ public static class RegularFileFullMatchLister implements ModelFileLister {
+ private final String modelType;
+ private final File targetFile;
+ private final Supplier<Boolean> isEnabled;
+
+ /**
+ * @param modelType the type of the model
+ * @param targetFile the expected model file
+ * @param isEnabled whether this lister is enabled
+ */
+ public RegularFileFullMatchLister(
+ @ModelTypeDef String modelType, File targetFile, Supplier<Boolean> isEnabled) {
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.targetFile = Preconditions.checkNotNull(targetFile);
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) {
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ if (!targetFile.exists()) {
+ return ImmutableList.of();
+ }
+ try {
+ return ImmutableList.of(ModelFile.createFromRegularFile(targetFile, modelType));
+ } catch (IOException e) {
+ TcLog.e(
+ TAG, "Failed to call createFromRegularFile with: " + targetFile.getAbsolutePath(), e);
+ }
+ return ImmutableList.of();
+ }
+ }
+
+ /** Lists model file in a specified folder by doing pattern matching on file names. */
+ public static class RegularFilePatternMatchLister implements ModelFileLister {
+ private final String modelType;
+ private final File folder;
+ private final Pattern fileNamePattern;
+ private final Supplier<Boolean> isEnabled;
+
+ /**
+ * @param modelType the type of the model
+ * @param folder the folder to list files
+ * @param fileNameRegex the regex to match the file name in the specified folder
+ * @param isEnabled whether the lister is enabled
+ */
+ public RegularFilePatternMatchLister(
+ @ModelTypeDef String modelType,
+ File folder,
+ String fileNameRegex,
+ Supplier<Boolean> isEnabled) {
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.folder = Preconditions.checkNotNull(folder);
+ this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) {
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ if (!folder.isDirectory()) {
+ return ImmutableList.of();
+ }
+ File[] files = folder.listFiles();
+ if (files == null) {
+ return ImmutableList.of();
+ }
+ ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+ for (File file : files) {
+ final Matcher matcher = fileNamePattern.matcher(file.getName());
+ if (!matcher.matches() || !file.isFile()) {
+ continue;
+ }
+ try {
+ modelFilesBuilder.add(ModelFile.createFromRegularFile(file, modelType));
+ } catch (IOException e) {
+ TcLog.w(TAG, "Failed to call createFromRegularFile with: " + file.getAbsolutePath());
+ }
+ }
+ return modelFilesBuilder.build();
+ }
+ }
+
+ /** Lists the model files preloaded in the APK file. */
+ public static class AssetFilePatternMatchLister implements ModelFileLister {
+ private final AssetManager assetManager;
+ private final String modelType;
+ private final String pathToList;
+ private final Pattern fileNamePattern;
+ private final Supplier<Boolean> isEnabled;
+ private final Object lock = new Object();
+ // Assets won't change without updating the app, so cache the result for performance reason.
+ @GuardedBy("lock")
+ private final Map<String, ImmutableList<ModelFile>> resultCache;
+
+ /**
+ * @param modelType the type of the model.
+ * @param pathToList the folder to list files
+ * @param fileNameRegex the regex to match the file name in the specified folder
+ * @param isEnabled whether this lister is enabled
+ */
+ public AssetFilePatternMatchLister(
+ AssetManager assetManager,
+ @ModelTypeDef String modelType,
+ String pathToList,
+ String fileNameRegex,
+ Supplier<Boolean> isEnabled) {
+ this.assetManager = Preconditions.checkNotNull(assetManager);
+ this.modelType = Preconditions.checkNotNull(modelType);
+ this.pathToList = Preconditions.checkNotNull(pathToList);
+ this.fileNamePattern = Pattern.compile(Preconditions.checkNotNull(fileNameRegex));
+ this.isEnabled = Preconditions.checkNotNull(isEnabled);
+ resultCache = new ArrayMap<>();
+ }
+
+ @Override
+ public ImmutableList<ModelFile> list(@ModelTypeDef String modelType) {
+ if (!this.modelType.equals(modelType)) {
+ return ImmutableList.of();
+ }
+ if (!isEnabled.get()) {
+ return ImmutableList.of();
+ }
+ synchronized (lock) {
+ if (resultCache.get(modelType) != null) {
+ return resultCache.get(modelType);
+ }
+ String[] fileNames = null;
+ try {
+ fileNames = assetManager.list(pathToList);
+ } catch (IOException e) {
+ TcLog.e(TAG, "Failed to list assets", e);
+ }
+ if (fileNames == null) {
+ return ImmutableList.of();
+ }
+ ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
+ for (String fileName : fileNames) {
+ final Matcher matcher = fileNamePattern.matcher(fileName);
+ if (!matcher.matches()) {
+ continue;
+ }
+ String absolutePath =
+ new StringBuilder(pathToList).append('/').append(fileName).toString();
+ try {
+ modelFilesBuilder.add(ModelFile.createFromAsset(assetManager, absolutePath, modelType));
+ } catch (IOException e) {
+ TcLog.w(TAG, "Failed to call createFromAsset with: " + absolutePath);
+ }
+ }
+ ImmutableList<ModelFile> result = modelFilesBuilder.build();
+ resultCache.put(modelType, result);
+ return result;
+ }
+ }
}
/**
* Returns the best model file for the given localelist, {@code null} if nothing is found.
*
* @param modelType the type of model to look up (e.g. annotator, lang_id, etc.)
- * @param localeList an ordered list of user preferences for locales, use {@code null} if there is
- * no preference.
+ * @param localePreferences an ordered list of user preferences for locales, use {@code null} if
+ * there is no preference.
*/
@Nullable
public ModelFile findBestModelFile(
- @ModelType.ModelTypeDef String modelType, @Nullable LocaleList localeList) {
+ @ModelTypeDef String modelType, @Nullable LocaleList localePreferences) {
final String languages =
- localeList == null || localeList.isEmpty()
+ localePreferences == null || localePreferences.isEmpty()
? LocaleList.getDefault().toLanguageTags()
- : localeList.toLanguageTags();
+ : localePreferences.toLanguageTags();
final List<Locale.LanguageRange> languageRangeList = Locale.LanguageRange.parse(languages);
ModelFile bestModel = null;
@@ -131,6 +349,34 @@
}
/**
+ * Deletes model files that are not preferred for any locales in user's preference.
+ *
+ * <p>This method will be invoked as a clean-up after we download a new model successfully. Race
+ * conditions are hard to avoid because we do not hold locks for files. But it should rarely cause
+ * any issues since it's safe to delete a model file in use (b/c we mmap it to memory).
+ */
+ public void deleteUnusedModelFiles() {
+ TcLog.d(TAG, "Start to delete unused model files.");
+ LocaleList localeList = LocaleList.getDefault();
+ for (@ModelTypeDef String modelType : ModelType.values()) {
+ ArraySet<ModelFile> allModelFiles = new ArraySet<>(listModelFiles(modelType));
+ for (int i = 0; i < localeList.size(); i++) {
+ // If a model file is preferred for any local in locale list, then keep it
+ ModelFile bestModel = findBestModelFile(modelType, new LocaleList(localeList.get(i)));
+ allModelFiles.remove(bestModel);
+ }
+ for (ModelFile modelFile : allModelFiles) {
+ if (modelFile.canWrite()) {
+ TcLog.d(TAG, "Deleting model: " + modelFile);
+ if (!modelFile.delete()) {
+ TcLog.w(TAG, "Failed to delete model: " + modelFile);
+ }
+ }
+ }
+ }
+ }
+
+ /**
* Returns a {@link File} that represents the destination to download a model.
*
* <p>Each model file's name is uniquely formatted based on its unique remote manifest URL.
@@ -140,9 +386,9 @@
* @param modelType the type of the model image to download
* @param manifestUrl the unique remote url of the model manifest
*/
- public File getDownloadTargetFile(@ModelType.ModelTypeDef String modelType, String manifestUrl) {
+ public File getDownloadTargetFile(@ModelTypeDef String modelType, String manifestUrl) {
String fileName = String.format("%s.%d.model", modelType, manifestUrl.hashCode());
- return new File(downloadModelDir, fileName);
+ return new File(modelDownloaderDir, fileName);
}
/**
@@ -153,7 +399,7 @@
public void dump(IndentingPrintWriter printWriter) {
printWriter.println("ModelFileManager:");
printWriter.increaseIndent();
- for (@ModelType.ModelTypeDef String modelType : ModelType.values()) {
+ for (@ModelTypeDef String modelType : ModelType.values()) {
printWriter.println(modelType + " model file(s):");
printWriter.increaseIndent();
for (ModelFile modelFile : listModelFiles(modelType)) {
@@ -164,260 +410,102 @@
printWriter.decreaseIndent();
}
- /** Default implementation of the model file supplier. */
- @VisibleForTesting
- static final class ModelFileSupplierImpl implements Supplier<ImmutableList<ModelFile>> {
- private static final String FACTORY_MODEL_DIR = "/etc/textclassifier/";
+ /** Fetch metadata of a model file. */
+ private static class ModelInfoFetcher {
+ private final Function<AssetFileDescriptor, Integer> versionFetcher;
+ private final Function<AssetFileDescriptor, String> supportedLocalesFetcher;
- private static final class ModelFileInfo {
- private final String modelNameRegex;
- private final String configUpdaterModelPath;
- private final Function<Integer, Integer> versionSupplier;
- private final Function<Integer, String> supportedLocalesSupplier;
-
- public ModelFileInfo(
- String modelNameRegex,
- String configUpdaterModelPath,
- Function<Integer, Integer> versionSupplier,
- Function<Integer, String> supportedLocalesSupplier) {
- this.modelNameRegex = Preconditions.checkNotNull(modelNameRegex);
- this.configUpdaterModelPath = Preconditions.checkNotNull(configUpdaterModelPath);
- this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
- this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
- }
-
- public String getModelNameRegex() {
- return modelNameRegex;
- }
-
- public String getConfigUpdaterModelPath() {
- return configUpdaterModelPath;
- }
-
- public Function<Integer, Integer> getVersionSupplier() {
- return versionSupplier;
- }
-
- public Function<Integer, String> getSupportedLocalesSupplier() {
- return supportedLocalesSupplier;
- }
+ private ModelInfoFetcher(
+ Function<AssetFileDescriptor, Integer> versionFetcher,
+ Function<AssetFileDescriptor, String> supportedLocalesFetcher) {
+ this.versionFetcher = versionFetcher;
+ this.supportedLocalesFetcher = supportedLocalesFetcher;
}
- private static final ImmutableMap<String, ModelFileInfo> MODEL_FILE_INFO_MAP =
- ImmutableMap.<String, ModelFileInfo>builder()
- .put(
- ModelType.ANNOTATOR,
- new ModelFileInfo(
- "(annotator|textclassifier)\\.(.*)\\.model",
- "/data/misc/textclassifier/textclassifier.model",
- AnnotatorModel::getVersion,
- AnnotatorModel::getLocales))
- .put(
- ModelType.LANG_ID,
- new ModelFileInfo(
- "lang_id.model",
- "/data/misc/textclassifier/lang_id.model",
- LangIdModel::getVersion,
- fd -> ModelFile.LANGUAGE_INDEPENDENT))
- .put(
- ModelType.ACTIONS_SUGGESTIONS,
- new ModelFileInfo(
- "actions_suggestions\\.(.*)\\.model",
- "/data/misc/textclassifier/actions_suggestions.model",
- ActionsSuggestionsModel::getVersion,
- ActionsSuggestionsModel::getLocales))
- .build();
-
- private final TextClassifierSettings settings;
- @ModelType.ModelTypeDef private final String modelType;
- private final File configUpdaterModelFile;
- private final File downloaderModelDir;
- private final File factoryModelDir;
- private final Pattern modelFilenamePattern;
- private final Function<Integer, Integer> versionSupplier;
- private final Function<Integer, String> supportedLocalesSupplier;
- private final Object lock = new Object();
-
- @GuardedBy("lock")
- private ImmutableList<ModelFile> factoryModels;
-
- public ModelFileSupplierImpl(
- TextClassifierSettings settings,
- @ModelType.ModelTypeDef String modelType,
- File downloaderModelDir) {
- this(
- settings,
- modelType,
- new File(FACTORY_MODEL_DIR),
- MODEL_FILE_INFO_MAP.get(modelType).getModelNameRegex(),
- new File(MODEL_FILE_INFO_MAP.get(modelType).getConfigUpdaterModelPath()),
- downloaderModelDir,
- MODEL_FILE_INFO_MAP.get(modelType).getVersionSupplier(),
- MODEL_FILE_INFO_MAP.get(modelType).getSupportedLocalesSupplier());
+ int getVersion(AssetFileDescriptor assetFileDescriptor) {
+ return versionFetcher.apply(assetFileDescriptor);
}
- @VisibleForTesting
- ModelFileSupplierImpl(
- TextClassifierSettings settings,
- @ModelType.ModelTypeDef String modelType,
- File factoryModelDir,
- String modelFileNameRegex,
- File configUpdaterModelFile,
- File downloaderModelDir,
- Function<Integer, Integer> versionSupplier,
- Function<Integer, String> supportedLocalesSupplier) {
- this.settings = Preconditions.checkNotNull(settings);
- this.modelType = Preconditions.checkNotNull(modelType);
- this.factoryModelDir = Preconditions.checkNotNull(factoryModelDir);
- this.modelFilenamePattern = Pattern.compile(Preconditions.checkNotNull(modelFileNameRegex));
- this.configUpdaterModelFile = Preconditions.checkNotNull(configUpdaterModelFile);
- this.downloaderModelDir = Preconditions.checkNotNull(downloaderModelDir);
- this.versionSupplier = Preconditions.checkNotNull(versionSupplier);
- this.supportedLocalesSupplier = Preconditions.checkNotNull(supportedLocalesSupplier);
+ String getSupportedLocales(AssetFileDescriptor assetFileDescriptor) {
+ return supportedLocalesFetcher.apply(assetFileDescriptor);
}
- @Override
- public ImmutableList<ModelFile> get() {
- final List<ModelFile> modelFiles = new ArrayList<>();
- // The dwonloader and config updater model have higher precedences.
- if (downloaderModelDir.exists() && settings.isModelDownloadManagerEnabled()) {
- modelFiles.addAll(getMatchedModelFiles(downloaderModelDir));
+ static ModelInfoFetcher create(@ModelTypeDef String modelType) {
+ switch (modelType) {
+ case ModelType.ANNOTATOR:
+ return new ModelInfoFetcher(AnnotatorModel::getVersion, AnnotatorModel::getLocales);
+ case ModelType.ACTIONS_SUGGESTIONS:
+ return new ModelInfoFetcher(
+ ActionsSuggestionsModel::getVersion, ActionsSuggestionsModel::getLocales);
+ case ModelType.LANG_ID:
+ return new ModelInfoFetcher(
+ LangIdModel::getVersion, afd -> ModelFile.LANGUAGE_INDEPENDENT);
+ default: // fall out
}
- if (configUpdaterModelFile.exists()) {
- final ModelFile updatedModel = createModelFile(configUpdaterModelFile);
- if (updatedModel != null) {
- modelFiles.add(updatedModel);
- }
- }
- // Factory models should never have overlapping locales, so the order doesn't matter.
- synchronized (lock) {
- if (factoryModels == null) {
- factoryModels = getMatchedModelFiles(factoryModelDir);
- }
- modelFiles.addAll(factoryModels);
- }
- return ImmutableList.copyOf(modelFiles);
- }
-
- private ImmutableList<ModelFile> getMatchedModelFiles(File parentDir) {
- ImmutableList.Builder<ModelFile> modelFilesBuilder = ImmutableList.builder();
- if (parentDir.exists() && parentDir.isDirectory()) {
- final File[] files = parentDir.listFiles();
- for (File file : files) {
- final Matcher matcher = modelFilenamePattern.matcher(file.getName());
- if (matcher.matches() && file.isFile()) {
- final ModelFile model = createModelFile(file);
- if (model != null) {
- modelFilesBuilder.add(model);
- }
- }
- }
- }
- return modelFilesBuilder.build();
- }
-
- /** Returns null if the path did not point to a compatible model. */
- @Nullable
- private ModelFile createModelFile(File file) {
- if (!file.exists()) {
- return null;
- }
- ParcelFileDescriptor modelFd = null;
- try {
- modelFd = ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
- if (modelFd == null) {
- return null;
- }
- final int modelFdInt = modelFd.getFd();
- final int version = versionSupplier.apply(modelFdInt);
- final String supportedLocalesStr = supportedLocalesSupplier.apply(modelFdInt);
- if (supportedLocalesStr.isEmpty()) {
- TcLog.d(TAG, "Ignoring " + file.getAbsolutePath());
- return null;
- }
- final List<Locale> supportedLocales = new ArrayList<>();
- for (String langTag : Splitter.on(',').split(supportedLocalesStr)) {
- supportedLocales.add(Locale.forLanguageTag(langTag));
- }
- return new ModelFile(
- modelType,
- file,
- version,
- supportedLocales,
- supportedLocalesStr,
- ModelFile.LANGUAGE_INDEPENDENT.equals(supportedLocalesStr));
- } catch (FileNotFoundException e) {
- TcLog.e(TAG, "Failed to find " + file.getAbsolutePath(), e);
- return null;
- } finally {
- maybeCloseAndLogError(modelFd);
- }
- }
-
- /** Closes the ParcelFileDescriptor, if non-null, and logs any errors that occur. */
- private static void maybeCloseAndLogError(@Nullable ParcelFileDescriptor fd) {
- if (fd == null) {
- return;
- }
- try {
- fd.close();
- } catch (IOException e) {
- TcLog.e(TAG, "Error closing file.", e);
- }
+ throw new IllegalStateException("Unsupported model types");
}
}
/** Describes TextClassifier model files on disk. */
- public static final class ModelFile {
- public static final String LANGUAGE_INDEPENDENT = "*";
+ public static class ModelFile {
+ @VisibleForTesting static final String LANGUAGE_INDEPENDENT = "*";
- @ModelType.ModelTypeDef private final String modelType;
- private final File file;
- private final int version;
- private final List<Locale> supportedLocales;
- private final String supportedLocalesStr;
- private final boolean languageIndependent;
+ @ModelTypeDef public final String modelType;
+ public final String absolutePath;
+ public final int version;
+ public final LocaleList supportedLocales;
+ public final boolean languageIndependent;
+ public final boolean isAsset;
- public ModelFile(
- @ModelType.ModelTypeDef String modelType,
- File file,
+ public static ModelFile createFromRegularFile(File file, @ModelTypeDef String modelType)
+ throws IOException {
+ ParcelFileDescriptor pfd =
+ ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ try (AssetFileDescriptor afd = new AssetFileDescriptor(pfd, 0, file.length())) {
+ return createFromAssetFileDescriptor(
+ file.getAbsolutePath(), modelType, afd, /* isAsset= */ false);
+ }
+ }
+
+ public static ModelFile createFromAsset(
+ AssetManager assetManager, String absolutePath, @ModelTypeDef String modelType)
+ throws IOException {
+ try (AssetFileDescriptor assetFileDescriptor = assetManager.openFd(absolutePath)) {
+ return createFromAssetFileDescriptor(
+ absolutePath, modelType, assetFileDescriptor, /* isAsset= */ true);
+ }
+ }
+
+ private static ModelFile createFromAssetFileDescriptor(
+ String absolutePath,
+ @ModelTypeDef String modelType,
+ AssetFileDescriptor assetFileDescriptor,
+ boolean isAsset) {
+ ModelInfoFetcher modelInfoFetcher = ModelInfoFetcher.create(modelType);
+ return new ModelFile(
+ modelType,
+ absolutePath,
+ modelInfoFetcher.getVersion(assetFileDescriptor),
+ modelInfoFetcher.getSupportedLocales(assetFileDescriptor),
+ isAsset);
+ }
+
+ @VisibleForTesting
+ ModelFile(
+ @ModelTypeDef String modelType,
+ String absolutePath,
int version,
- List<Locale> supportedLocales,
- String supportedLocalesStr,
- boolean languageIndependent) {
- this.modelType = Preconditions.checkNotNull(modelType);
- this.file = Preconditions.checkNotNull(file);
+ String supportedLocaleTags,
+ boolean isAsset) {
+ this.modelType = modelType;
+ this.absolutePath = absolutePath;
this.version = version;
- this.supportedLocales = Preconditions.checkNotNull(supportedLocales);
- this.supportedLocalesStr = Preconditions.checkNotNull(supportedLocalesStr);
- this.languageIndependent = languageIndependent;
- }
-
- /** Returns the type of this model, defined in {@link ModelType}. */
- @ModelType.ModelTypeDef
- public String getModelType() {
- return modelType;
- }
-
- /** Returns the absolute path to the model file. */
- public String getPath() {
- return file.getAbsolutePath();
- }
-
- /** Returns a name to use for id generation, effectively the name of the model file. */
- public String getName() {
- return file.getName();
- }
-
- /** Returns the version tag in the model's metadata. */
- public int getVersion() {
- return version;
- }
-
- /** Returns whether the language supports any language in the given ranges. */
- public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
- Preconditions.checkNotNull(languageRanges);
- return languageIndependent || Locale.lookup(languageRanges, supportedLocales) != null;
+ this.languageIndependent = LANGUAGE_INDEPENDENT.equals(supportedLocaleTags);
+ this.supportedLocales =
+ languageIndependent
+ ? LocaleList.getEmptyLocaleList()
+ : LocaleList.forLanguageTags(supportedLocaleTags);
+ this.isAsset = isAsset;
}
/** Returns if this model file is preferred to the given one. */
@@ -437,70 +525,111 @@
}
// A higher-version model is preferred.
- if (version > model.getVersion()) {
+ if (version > model.version) {
return true;
}
return false;
}
+ /** Returns whether the language supports any language in the given ranges. */
+ public boolean isAnyLanguageSupported(List<Locale.LanguageRange> languageRanges) {
+ Preconditions.checkNotNull(languageRanges);
+ if (languageIndependent) {
+ return true;
+ }
+ List<String> supportedLocaleTags =
+ Arrays.asList(supportedLocales.toLanguageTags().split(","));
+ return Locale.lookupTag(languageRanges, supportedLocaleTags) != null;
+ }
+
+ public AssetFileDescriptor open(AssetManager assetManager) throws IOException {
+ if (isAsset) {
+ return assetManager.openFd(absolutePath);
+ }
+ File file = new File(absolutePath);
+ ParcelFileDescriptor parcelFileDescriptor =
+ ParcelFileDescriptor.open(file, ParcelFileDescriptor.MODE_READ_ONLY);
+ return new AssetFileDescriptor(parcelFileDescriptor, 0, file.length());
+ }
+
+ public boolean canWrite() {
+ if (isAsset) {
+ return false;
+ }
+ return new File(absolutePath).canWrite();
+ }
+
+ public boolean delete() {
+ if (isAsset) {
+ throw new IllegalStateException("asset is read-only, deleting it is not allowed.");
+ }
+ return new File(absolutePath).delete();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (!(o instanceof ModelFile)) {
+ return false;
+ }
+ ModelFile modelFile = (ModelFile) o;
+ return version == modelFile.version
+ && languageIndependent == modelFile.languageIndependent
+ && isAsset == modelFile.isAsset
+ && Objects.equals(modelType, modelFile.modelType)
+ && Objects.equals(absolutePath, modelFile.absolutePath)
+ && Objects.equals(supportedLocales, modelFile.supportedLocales);
+ }
+
@Override
public int hashCode() {
- return Objects.hash(getPath());
- }
-
- @Override
- public boolean equals(Object other) {
- if (this == other) {
- return true;
- }
- if (other instanceof ModelFile) {
- final ModelFile otherModel = (ModelFile) other;
- return TextUtils.equals(getPath(), otherModel.getPath());
- }
- return false;
+ return Objects.hash(
+ modelType, absolutePath, version, supportedLocales, languageIndependent, isAsset);
}
public ModelInfo toModelInfo() {
- return new ModelInfo(getVersion(), supportedLocalesStr);
+ return new ModelInfo(version, supportedLocales.toLanguageTags());
}
@Override
public String toString() {
return String.format(
Locale.US,
- "ModelFile { type=%s path=%s name=%s version=%d locales=%s }",
+ "ModelFile { type=%s path=%s version=%d locales=%s isAsset=%b}",
modelType,
- getPath(),
- getName(),
+ absolutePath,
version,
- supportedLocalesStr);
+ languageIndependent ? LANGUAGE_INDEPENDENT : supportedLocales.toLanguageTags(),
+ isAsset);
}
public static ImmutableList<Optional<ModelInfo>> toModelInfos(
- Optional<ModelFile>... modelFiles) {
+ Optional<ModelFileManager.ModelFile>... modelFiles) {
return Arrays.stream(modelFiles)
- .map(modelFile -> modelFile.transform(ModelFile::toModelInfo))
+ .map(modelFile -> modelFile.transform(ModelFileManager.ModelFile::toModelInfo))
.collect(Collectors.collectingAndThen(Collectors.toList(), ImmutableList::copyOf));
}
+ }
- /** Effectively an enum class to represent types of models. */
- public static final class ModelType {
- @Retention(RetentionPolicy.SOURCE)
- @StringDef({ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS})
- public @interface ModelTypeDef {}
+ /** Effectively an enum class to represent types of models. */
+ public static final class ModelType {
+ @Retention(RetentionPolicy.SOURCE)
+ @StringDef({ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS})
+ @interface ModelTypeDef {}
- public static final String ANNOTATOR = "annotator";
- public static final String LANG_ID = "lang_id";
- public static final String ACTIONS_SUGGESTIONS = "actions_suggestions";
+ public static final String ANNOTATOR = "annotator";
+ public static final String LANG_ID = "lang_id";
+ public static final String ACTIONS_SUGGESTIONS = "actions_suggestions";
- public static final ImmutableList<String> VALUES =
- ImmutableList.of(ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS);
+ public static final ImmutableList<String> VALUES =
+ ImmutableList.of(ANNOTATOR, LANG_ID, ACTIONS_SUGGESTIONS);
- public static ImmutableList<String> values() {
- return VALUES;
- }
-
- private ModelType() {}
+ public static ImmutableList<String> values() {
+ return VALUES;
}
+
+ private ModelType() {}
}
}
diff --git a/java/src/com/android/textclassifier/TextClassifierImpl.java b/java/src/com/android/textclassifier/TextClassifierImpl.java
index dbe66f6..b824ed0 100644
--- a/java/src/com/android/textclassifier/TextClassifierImpl.java
+++ b/java/src/com/android/textclassifier/TextClassifierImpl.java
@@ -22,6 +22,7 @@
import android.app.RemoteAction;
import android.content.Context;
import android.content.Intent;
+import android.content.res.AssetFileDescriptor;
import android.icu.util.ULocale;
import android.os.Bundle;
import android.os.LocaleList;
@@ -42,7 +43,7 @@
import androidx.annotation.WorkerThread;
import androidx.core.util.Pair;
import com.android.textclassifier.ModelFileManager.ModelFile;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType;
import com.android.textclassifier.common.base.TcLog;
import com.android.textclassifier.common.intent.LabeledIntent;
import com.android.textclassifier.common.intent.TemplateIntentFactory;
@@ -61,6 +62,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableList;
+import java.io.IOException;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.ArrayList;
@@ -123,7 +125,7 @@
}
@WorkerThread
- TextSelection suggestSelection(TextSelection.Request request) {
+ TextSelection suggestSelection(TextSelection.Request request) throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
final int rangeLength = request.getEndIndex() - request.getStartIndex();
@@ -182,7 +184,7 @@
}
@WorkerThread
- TextClassification classifyText(TextClassification.Request request) {
+ TextClassification classifyText(TextClassification.Request request) throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
LangIdModel langId = getLangIdImpl();
@@ -222,7 +224,7 @@
}
@WorkerThread
- TextLinks generateLinks(TextLinks.Request request) {
+ TextLinks generateLinks(TextLinks.Request request) throws IOException {
Preconditions.checkNotNull(request);
Preconditions.checkArgument(
request.getText().length() <= getMaxGenerateLinksTextLength(),
@@ -332,7 +334,7 @@
TextClassifierEventConverter.fromPlatform(event));
}
- TextLanguage detectLanguage(TextLanguage.Request request) {
+ TextLanguage detectLanguage(TextLanguage.Request request) throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
final TextLanguage.Builder builder = new TextLanguage.Builder();
@@ -345,7 +347,8 @@
return builder.build();
}
- ConversationActions suggestConversationActions(ConversationActions.Request request) {
+ ConversationActions suggestConversationActions(ConversationActions.Request request)
+ throws IOException {
Preconditions.checkNotNull(request);
checkMainThread();
ActionsSuggestionsModel actionsImpl = getActionsImpl();
@@ -431,7 +434,7 @@
return request.getTypeConfig().resolveEntityListModifications(defaultActionTypes);
}
- private AnnotatorModel getAnnotatorImpl(LocaleList localeList) {
+ private AnnotatorModel getAnnotatorImpl(LocaleList localeList) throws IOException {
synchronized (lock) {
localeList = localeList == null ? LocaleList.getDefault() : localeList;
final ModelFileManager.ModelFile bestModel =
@@ -444,31 +447,35 @@
// The current annotator model may be still used by another thread / model.
// Do not call close() here, and let the GC to clean it up when no one else
// is using it.
- annotatorImpl = new AnnotatorModel(bestModel.getPath());
- annotatorImpl.setLangIdModel(getLangIdImpl());
- annotatorModelInUse = bestModel;
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ annotatorImpl = new AnnotatorModel(afd);
+ annotatorImpl.setLangIdModel(getLangIdImpl());
+ annotatorModelInUse = bestModel;
+ }
}
return annotatorImpl;
}
}
- private LangIdModel getLangIdImpl() {
+ private LangIdModel getLangIdImpl() throws IOException {
synchronized (lock) {
final ModelFileManager.ModelFile bestModel =
- modelFileManager.findBestModelFile(ModelType.LANG_ID, /* localeList= */ null);
+ modelFileManager.findBestModelFile(ModelType.LANG_ID, /* localePreferences= */ null);
if (bestModel == null) {
throw new IllegalStateException("Failed to find the best LangID model.");
}
if (langIdImpl == null || !Objects.equals(langIdModelInUse, bestModel)) {
TcLog.d(TAG, "Loading " + bestModel);
- langIdImpl = new LangIdModel(bestModel.getPath());
- langIdModelInUse = bestModel;
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ langIdImpl = new LangIdModel(afd);
+ langIdModelInUse = bestModel;
+ }
}
return langIdImpl;
}
}
- private ActionsSuggestionsModel getActionsImpl() {
+ private ActionsSuggestionsModel getActionsImpl() throws IOException {
synchronized (lock) {
// TODO: Use LangID to determine the locale we should use here?
final ModelFileManager.ModelFile bestModel =
@@ -479,8 +486,10 @@
}
if (actionsImpl == null || !Objects.equals(actionModelInUse, bestModel)) {
TcLog.d(TAG, "Loading " + bestModel);
- actionsImpl = new ActionsSuggestionsModel(bestModel.getPath());
- actionModelInUse = bestModel;
+ try (AssetFileDescriptor afd = bestModel.open(context.getAssets())) {
+ actionsImpl = new ActionsSuggestionsModel(afd);
+ actionModelInUse = bestModel;
+ }
}
return actionsImpl;
}
diff --git a/java/src/com/android/textclassifier/TextClassifierSettings.java b/java/src/com/android/textclassifier/TextClassifierSettings.java
index c13ff68..b13d166 100644
--- a/java/src/com/android/textclassifier/TextClassifierSettings.java
+++ b/java/src/com/android/textclassifier/TextClassifierSettings.java
@@ -20,7 +20,8 @@
import android.view.textclassifier.ConversationAction;
import android.view.textclassifier.TextClassifier;
import androidx.annotation.NonNull;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType.ModelTypeDef;
import com.android.textclassifier.utils.IndentingPrintWriter;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Splitter;
@@ -44,6 +45,7 @@
* @see android.provider.DeviceConfig#NAMESPACE_TEXTCLASSIFIER
*/
public final class TextClassifierSettings {
+ private static final String TAG = "TextClassifierSettings";
public static final String NAMESPACE = DeviceConfig.NAMESPACE_TEXTCLASSIFIER;
private static final String DELIMITER = ":";
@@ -109,6 +111,9 @@
/** Whether to enable model downloading with ModelDownloadManager */
@VisibleForTesting
static final String MODEL_DOWNLOAD_MANAGER_ENABLED = "model_download_manager_enabled";
+ /** Type of network to download model manifest. A String value of androidx.work.NetworkType. */
+ private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE =
+ "manifest_download_required_network_type";
/** Max attempts allowed for a single ModelDownloader downloading task. */
@VisibleForTesting
static final String MODEL_DOWNLOAD_MAX_ATTEMPTS = "model_download_max_attempts";
@@ -196,6 +201,8 @@
private static final boolean TRANSLATE_IN_CLASSIFICATION_ENABLED_DEFAULT = true;
private static final boolean DETECT_LANGUAGES_FROM_TEXT_ENABLED_DEFAULT = true;
private static final boolean MODEL_DOWNLOAD_MANAGER_ENABLED_DEFAULT = false;
+ // Manifest files are usually small, default to any network type
+ private static final String MANIFEST_DOWNLOAD_REQUIRED_NETWORK_TYPE_DEFAULT = "NOT_ROAMING";
private static final int MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT = 5;
private static final String ANNOTATOR_URL_PREFIX_DEFAULT =
"https://www.gstatic.com/android/text_classifier/";
@@ -360,7 +367,7 @@
NAMESPACE, MODEL_DOWNLOAD_MAX_ATTEMPTS, MODEL_DOWNLOAD_MAX_ATTEMPTS_DEFAULT);
}
- public String getModelURLPrefix(@ModelType.ModelTypeDef String modelType) {
+ public String getModelURLPrefix(@ModelTypeDef String modelType) {
switch (modelType) {
case ModelType.ANNOTATOR:
return deviceConfig.getString(
@@ -375,7 +382,7 @@
}
}
- public String getPrimaryModelURLSuffix(@ModelType.ModelTypeDef String modelType) {
+ public String getPrimaryModelURLSuffix(@ModelTypeDef String modelType) {
switch (modelType) {
case ModelType.ANNOTATOR:
return deviceConfig.getString(
diff --git a/java/src/com/android/textclassifier/common/base/TcLog.java b/java/src/com/android/textclassifier/common/base/TcLog.java
index 87f1187..05a2443 100644
--- a/java/src/com/android/textclassifier/common/base/TcLog.java
+++ b/java/src/com/android/textclassifier/common/base/TcLog.java
@@ -16,6 +16,8 @@
package com.android.textclassifier.common.base;
+import android.util.Log;
+
/**
* Logging for android.view.textclassifier package.
*
@@ -31,27 +33,30 @@
public static final String TAG = "androidtc";
/** true: Enables full logging. false: Limits logging to debug level. */
- public static final boolean ENABLE_FULL_LOGGING =
- android.util.Log.isLoggable(TAG, android.util.Log.VERBOSE);
+ public static final boolean ENABLE_FULL_LOGGING = Log.isLoggable(TAG, Log.VERBOSE);
private TcLog() {}
public static void v(String tag, String msg) {
if (ENABLE_FULL_LOGGING) {
- android.util.Log.v(getTag(tag), msg);
+ Log.v(getTag(tag), msg);
}
}
public static void d(String tag, String msg) {
- android.util.Log.d(getTag(tag), msg);
+ Log.d(getTag(tag), msg);
}
public static void w(String tag, String msg) {
- android.util.Log.w(getTag(tag), msg);
+ Log.w(getTag(tag), msg);
+ }
+
+ public static void e(String tag, String msg) {
+ Log.e(getTag(tag), msg);
}
public static void e(String tag, String msg, Throwable tr) {
- android.util.Log.e(getTag(tag), msg, tr);
+ Log.e(getTag(tag), msg, tr);
}
private static String getTag(String customTag) {
diff --git a/java/tests/instrumentation/Android.bp b/java/tests/instrumentation/Android.bp
index a0cd0ec..fa31894 100644
--- a/java/tests/instrumentation/Android.bp
+++ b/java/tests/instrumentation/Android.bp
@@ -55,5 +55,5 @@
instrumentation_for: "TextClassifierService",
- data: ["testdata/*"]
+ data: ["testdata/*"],
}
\ No newline at end of file
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
index 09bd363..20fa508 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/DefaultTextClassifierServiceTest.java
@@ -43,7 +43,6 @@
import com.android.textclassifier.common.statsd.TextClassifierApiUsageLogger;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import java.util.List;
@@ -190,7 +189,7 @@
@Test
public void missingModelFile_onFailureShouldBeCalled() throws Exception {
testInjector.setModelFileManager(
- new ModelFileManager(TestDataUtils.getTestDataFolder(), ImmutableMap.of()));
+ new ModelFileManager(ApplicationProvider.getApplicationContext(), ImmutableList.of()));
defaultTextClassifierService.onCreate();
TextClassification.Request request = new TextClassification.Request.Builder("hi", 0, 2).build();
@@ -231,9 +230,14 @@
}
@Override
+ public Context getContext() {
+ return context;
+ }
+
+ @Override
public ModelFileManager createModelFileManager(TextClassifierSettings settings) {
if (modelFileManager == null) {
- return TestDataUtils.createModelFileManagerForTesting();
+ return TestDataUtils.createModelFileManagerForTesting(context);
}
return modelFileManager;
}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
index 8ef3908..de819ef 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/ModelFileManagerTest.java
@@ -16,29 +16,31 @@
package com.android.textclassifier;
+import static com.android.textclassifier.ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT;
import static com.google.common.truth.Truth.assertThat;
-import static org.mockito.ArgumentMatchers.anyBoolean;
-import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.when;
import android.os.LocaleList;
import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import com.android.textclassifier.ModelFileManager.ModelFile;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType.ModelTypeDef;
+import com.android.textclassifier.ModelFileManager.RegularFileFullMatchLister;
+import com.android.textclassifier.ModelFileManager.RegularFilePatternMatchLister;
import com.android.textclassifier.common.logging.ResultIdUtils.ModelInfo;
+import com.android.textclassifier.testing.SetDefaultLocalesRule;
import com.google.common.base.Optional;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
+import com.google.common.io.Files;
import java.io.File;
import java.io.IOException;
-import java.util.Collections;
import java.util.List;
import java.util.Locale;
-import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
@@ -51,22 +53,16 @@
private static final String URL = "http://www.gstatic.com/android/text_classifier/q/711/en.fb";
private static final String URL_2 = "http://www.gstatic.com/android/text_classifier/q/712/en.fb";
- @ModelFile.ModelType.ModelTypeDef
- private static final String MODEL_TYPE = ModelFile.ModelType.ANNOTATOR;
+ @ModelTypeDef private static final String MODEL_TYPE = ModelType.ANNOTATOR;
- @ModelFile.ModelType.ModelTypeDef
- private static final String MODEL_TYPE_2 = ModelFile.ModelType.LANG_ID;
+ @ModelTypeDef private static final String MODEL_TYPE_2 = ModelType.LANG_ID;
- @Mock private Supplier<ImmutableList<ModelFile>> modelFileSupplier;
@Mock private TextClassifierSettings.IDeviceConfig mockDeviceConfig;
- private File rootTestDir;
- private File factoryModelDir;
- private File configUpdaterModelFile;
- private File downloaderModelDir;
+ @Rule public final SetDefaultLocalesRule setDefaultLocalesRule = new SetDefaultLocalesRule();
+ private File rootTestDir;
private ModelFileManager modelFileManager;
- private ModelFileManager.ModelFileSupplierImpl modelFileSupplierImpl;
@Before
public void setup() {
@@ -74,28 +70,11 @@
rootTestDir =
new File(ApplicationProvider.getApplicationContext().getCacheDir(), "rootTestDir");
- factoryModelDir = new File(rootTestDir, "factory");
- configUpdaterModelFile = new File(rootTestDir, "configupdater.model");
- downloaderModelDir = new File(rootTestDir, "downloader");
-
- modelFileManager =
- new ModelFileManager(downloaderModelDir, ImmutableMap.of(MODEL_TYPE, modelFileSupplier));
- modelFileSupplierImpl =
- new ModelFileManager.ModelFileSupplierImpl(
- new TextClassifierSettings(mockDeviceConfig),
- MODEL_TYPE,
- factoryModelDir,
- "test\\d.model",
- configUpdaterModelFile,
- downloaderModelDir,
- fd -> 1,
- fd -> ModelFileManager.ModelFile.LANGUAGE_INDEPENDENT);
-
rootTestDir.mkdirs();
- factoryModelDir.mkdirs();
- downloaderModelDir.mkdirs();
-
- Locale.setDefault(DEFAULT_LOCALE);
+ modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ new TextClassifierSettings(mockDeviceConfig));
}
@After
@@ -104,100 +83,106 @@
}
@Test
- public void get() {
- ModelFileManager.ModelFile modelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
- when(modelFileSupplier.get()).thenReturn(ImmutableList.of(modelFile));
+ public void annotatorModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.ANNOTATOR, "textclassifier/annotator.universal.model");
+ }
- List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(MODEL_TYPE);
+ @Test
+ public void actionsModelPreloaded() {
+ verifyModelPreloadedAsAsset(
+ ModelType.ACTIONS_SUGGESTIONS, "textclassifier/actions_suggestions.universal.model");
+ }
- assertThat(modelFiles).hasSize(1);
- assertThat(modelFiles.get(0)).isEqualTo(modelFile);
+ @Test
+ public void langIdModelPreloaded() {
+ verifyModelPreloadedAsAsset(ModelType.LANG_ID, "textclassifier/lang_id.model");
+ }
+
+ private void verifyModelPreloadedAsAsset(
+ @ModelTypeDef String modelType, String expectedModelPath) {
+ List<ModelFileManager.ModelFile> modelFiles = modelFileManager.listModelFiles(modelType);
+ List<ModelFile> assetFiles =
+ modelFiles.stream().filter(modelFile -> modelFile.isAsset).collect(Collectors.toList());
+
+ assertThat(assetFiles).hasSize(1);
+ assertThat(assetFiles.get(0).absolutePath).isEqualTo(expectedModelPath);
}
@Test
public void findBestModel_versionCode() {
ModelFileManager.ModelFile olderModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
ModelFileManager.ModelFile newerModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/b"), 2, ImmutableList.of(), "", true);
- when(modelFileSupplier.get()).thenReturn(ImmutableList.of(olderModelFile, newerModelFile));
+ MODEL_TYPE,
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(olderModelFile, newerModelFile)));
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.getEmptyLocaleList());
-
+ ModelFile bestModelFile = modelFileManager.findBestModelFile(MODEL_TYPE, null);
assertThat(bestModelFile).isEqualTo(newerModelFile);
}
@Test
public void findBestModel_languageDependentModelIsPreferred() {
- Locale locale = Locale.forLanguageTag("ja");
ModelFileManager.ModelFile languageIndependentModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
ModelFileManager.ModelFile languageDependentModelFile =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(locale),
- locale.toLanguageTag(),
- false);
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
+ DEFAULT_LOCALE.toLanguageTag(),
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType ->
+ ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(
- MODEL_TYPE, LocaleList.forLanguageTags(locale.toLanguageTag()));
+ ModelFile bestModelFile =
+ modelFileManager.findBestModelFile(MODEL_TYPE, new LocaleList(DEFAULT_LOCALE));
assertThat(bestModelFile).isEqualTo(languageDependentModelFile);
}
@Test
public void findBestModel_noMatchedLanguageModel() {
- Locale locale = Locale.forLanguageTag("ja");
ModelFileManager.ModelFile languageIndependentModelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
+ MODEL_TYPE,
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
ModelFileManager.ModelFile languageDependentModelFile =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(locale),
- locale.toLanguageTag(),
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
-
- ModelFileManager.ModelFile bestModelFile =
- modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
- assertThat(bestModelFile).isEqualTo(languageIndependentModelFile);
- }
-
- @Test
- public void findBestModel_noMatchedLanguageModel_defaultLocaleModelExists() {
- ModelFileManager.ModelFile languageIndependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(), "", true);
-
- ModelFileManager.ModelFile languageDependentModelFile =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(DEFAULT_LOCALE),
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 2,
DEFAULT_LOCALE.toLanguageTag(),
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(languageIndependentModelFile, languageDependentModelFile));
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType ->
+ ImmutableList.of(languageIndependentModelFile, languageDependentModelFile)));
ModelFileManager.ModelFile bestModelFile =
modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("zh-hk"));
@@ -209,23 +194,22 @@
ModelFileManager.ModelFile matchButOlderModel =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("fr")),
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
"fr",
- false);
-
+ /* isAsset= */ false);
ModelFileManager.ModelFile mismatchButNewerModel =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 2,
- Collections.singletonList(Locale.forLanguageTag("ja")),
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 1,
"ja",
- false);
-
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(matchButOlderModel, mismatchButNewerModel));
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType -> ImmutableList.of(matchButOlderModel, mismatchButNewerModel)));
ModelFileManager.ModelFile bestModelFile =
modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("fr"));
@@ -233,21 +217,26 @@
}
@Test
- public void findBestModel_languageIsMoreImportantThanVersion_bestModelComesFirst() {
+ public void findBestModel_preferMatchedLocaleModel() {
ModelFileManager.ModelFile matchLocaleModel =
new ModelFileManager.ModelFile(
MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
+ new File(rootTestDir, "a").getAbsolutePath(),
+ /* version= */ 1,
"ja",
- false);
-
+ /* isAsset= */ false);
ModelFileManager.ModelFile languageIndependentModel =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(), "", true);
- when(modelFileSupplier.get())
- .thenReturn(ImmutableList.of(matchLocaleModel, languageIndependentModel));
+ MODEL_TYPE,
+ new File(rootTestDir, "b").getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(
+ modelType -> ImmutableList.of(matchLocaleModel, languageIndependentModel)));
ModelFileManager.ModelFile bestModelFile =
modelFileManager.findBestModelFile(MODEL_TYPE, LocaleList.forLanguageTags("ja"));
@@ -256,9 +245,135 @@
}
@Test
+ public void deleteUnusedModelFiles_olderModelDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "ja", /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isFalse();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_languageIndependentOlderModelDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ model1.getAbsolutePath(),
+ /* version= */ 1,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE,
+ model2.getAbsolutePath(),
+ /* version= */ 2,
+ LANGUAGE_INDEPENDENT,
+ /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isFalse();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_modelOnlySupportingLocalesNotInListDeleted() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 1, "en", /* isAsset= */ false);
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ assertThat(model2.exists()).isFalse();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_multiLocalesInLocaleList() throws Exception {
+ File model1 = new File(rootTestDir, "model1.fb");
+ model1.createNewFile();
+ File model2 = new File(rootTestDir, "model2.fb");
+ model2.createNewFile();
+ ModelFileManager.ModelFile modelFile1 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager.ModelFile modelFile2 =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model2.getAbsolutePath(), /* version= */ 2, "en", /* isAsset= */ false);
+ setDefaultLocalesRule.set(
+ new LocaleList(Locale.forLanguageTag("ja"), Locale.forLanguageTag("en")));
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile1, modelFile2)));
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ assertThat(model2.exists()).isTrue();
+ }
+
+ @Test
+ public void deleteUnusedModelFiles_readOnlyModelsUntouched() throws Exception {
+ File readOnlyDir = new File(rootTestDir, "read_only/");
+ readOnlyDir.mkdirs();
+ File model1 = new File(readOnlyDir, "model1.fb");
+ model1.createNewFile();
+ readOnlyDir.setWritable(false);
+ ModelFileManager.ModelFile modelFile =
+ new ModelFileManager.ModelFile(
+ MODEL_TYPE, model1.getAbsolutePath(), /* version= */ 1, "ja", /* isAsset= */ false);
+ ModelFileManager modelFileManager =
+ new ModelFileManager(
+ ApplicationProvider.getApplicationContext(),
+ ImmutableList.of(modelType -> ImmutableList.of(modelFile)));
+ setDefaultLocalesRule.set(new LocaleList(Locale.forLanguageTag("ja")));
+
+ modelFileManager.deleteUnusedModelFiles();
+
+ assertThat(model1.exists()).isTrue();
+ }
+
+ @Test
public void getDownloadTargetFile_targetFileInCorrectDir() {
File targetFile = modelFileManager.getDownloadTargetFile(MODEL_TYPE, URL);
- assertThat(targetFile.getParentFile()).isEqualTo(downloaderModelDir);
+ assertThat(targetFile.getAbsolutePath())
+ .startsWith(ApplicationProvider.getApplicationContext().getFilesDir().getAbsolutePath());
}
@Test
@@ -278,21 +393,11 @@
public void modelFileEquals() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
assertThat(modelA).isEqualTo(modelB);
}
@@ -301,67 +406,23 @@
public void modelFile_different() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/b"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
assertThat(modelA).isNotEqualTo(modelB);
}
@Test
- public void modelFile_getPath() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA.getPath()).isEqualTo("/path/a");
- }
-
- @Test
- public void modelFile_getName() {
- ModelFileManager.ModelFile modelA =
- new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
-
- assertThat(modelA.getName()).isEqualTo("a");
- }
-
- @Test
public void modelFile_isPreferredTo_languageDependentIsBetter() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 1,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 1, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/b"), 2, ImmutableList.of(), "", true);
+ MODEL_TYPE, "/path/b", /* version= */ 2, LANGUAGE_INDEPENDENT, /* isAsset= */ false);
assertThat(modelA.isPreferredTo(modelB)).isTrue();
}
@@ -370,16 +431,11 @@
public void modelFile_isPreferredTo_version() {
ModelFileManager.ModelFile modelA =
new ModelFileManager.ModelFile(
- MODEL_TYPE,
- new File("/path/a"),
- 2,
- Collections.singletonList(Locale.forLanguageTag("ja")),
- "ja",
- false);
+ MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
ModelFileManager.ModelFile modelB =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/b"), 1, ImmutableList.of(), "", false);
+ MODEL_TYPE, "/path/b", /* version= */ 1, "ja", /* isAsset= */ false);
assertThat(modelA.isPreferredTo(modelB)).isTrue();
}
@@ -388,7 +444,7 @@
public void modelFile_toModelInfo() {
ModelFileManager.ModelFile modelFile =
new ModelFileManager.ModelFile(
- MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
+ MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
ModelInfo modelInfo = modelFile.toModelInfo();
@@ -398,11 +454,9 @@
@Test
public void modelFile_toModelInfos() {
ModelFile englishModelFile =
- new ModelFile(
- MODEL_TYPE, new File("/path/a"), 1, ImmutableList.of(Locale.ENGLISH), "en", false);
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 1, "en", /* isAsset= */ false);
ModelFile japaneseModelFile =
- new ModelFile(
- MODEL_TYPE, new File("/path/a"), 2, ImmutableList.of(Locale.JAPANESE), "ja", false);
+ new ModelFile(MODEL_TYPE, "/path/a", /* version= */ 2, "ja", /* isAsset= */ false);
ImmutableList<Optional<ModelInfo>> modelInfos =
ModelFileManager.ModelFile.toModelInfos(
@@ -417,64 +471,53 @@
}
@Test
- public void testFileSupplierImpl_updatedFileOnly() throws IOException {
- when(mockDeviceConfig.getBoolean(
- eq(TextClassifierSettings.NAMESPACE),
- eq(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED),
- anyBoolean()))
- .thenReturn(false);
- configUpdaterModelFile.createNewFile();
- File downloaderModelFile = new File(downloaderModelDir, "test0.model");
- downloaderModelFile.createNewFile();
- File model1 = new File(factoryModelDir, "test1.model");
- model1.createNewFile();
- File model2 = new File(factoryModelDir, "test2.model");
- model2.createNewFile();
- new File(factoryModelDir, "not_match_regex.model").createNewFile();
+ public void regularFileFullMatchLister() throws IOException {
+ File modelFile = new File(rootTestDir, "test.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile);
+ File wrongFile = new File(rootTestDir, "wrong.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), wrongFile);
- List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
- List<String> modelFilePaths =
- modelFiles.stream().map(modelFile -> modelFile.getPath()).collect(Collectors.toList());
+ RegularFileFullMatchLister regularFileFullMatchLister =
+ new RegularFileFullMatchLister(MODEL_TYPE, modelFile, () -> true);
+ ImmutableList<ModelFile> listedModels = regularFileFullMatchLister.list(MODEL_TYPE);
- assertThat(modelFiles).hasSize(3);
- assertThat(modelFilePaths)
- .containsExactly(
- configUpdaterModelFile.getAbsolutePath(),
- model1.getAbsolutePath(),
- model2.getAbsolutePath());
+ assertThat(listedModels).hasSize(1);
+ assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile.getAbsolutePath());
+ assertThat(listedModels.get(0).isAsset).isFalse();
}
@Test
- public void testFileSupplierImpl_includeDownloaderFile() throws IOException {
- when(mockDeviceConfig.getBoolean(
- eq(TextClassifierSettings.NAMESPACE),
- eq(TextClassifierSettings.MODEL_DOWNLOAD_MANAGER_ENABLED),
- anyBoolean()))
- .thenReturn(true);
- configUpdaterModelFile.createNewFile();
- File downloaderModelFile = new File(downloaderModelDir, "test0.model");
- downloaderModelFile.createNewFile();
- File factoryModelFile = new File(factoryModelDir, "test1.model");
- factoryModelFile.createNewFile();
+ public void regularFilePatternMatchLister() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
+ File modelFile2 = new File(rootTestDir, "annotator.fr.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile2);
+ File mismatchedModelFile = new File(rootTestDir, "actions.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), mismatchedModelFile);
- List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
- List<String> modelFilePaths =
- modelFiles.stream().map(ModelFile::getPath).collect(Collectors.toList());
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> true);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
- assertThat(modelFiles).hasSize(3);
- assertThat(modelFilePaths)
- .containsExactly(
- configUpdaterModelFile.getAbsolutePath(),
- downloaderModelFile.getAbsolutePath(),
- factoryModelFile.getAbsolutePath());
+ assertThat(listedModels).hasSize(2);
+ assertThat(listedModels.get(0).absolutePath).isEqualTo(modelFile1.getAbsolutePath());
+ assertThat(listedModels.get(0).isAsset).isFalse();
+ assertThat(listedModels.get(1).absolutePath).isEqualTo(modelFile2.getAbsolutePath());
+ assertThat(listedModels.get(1).isAsset).isFalse();
}
@Test
- public void testFileSupplierImpl_empty() {
- factoryModelDir.delete();
- List<ModelFileManager.ModelFile> modelFiles = modelFileSupplierImpl.get();
+ public void regularFilePatternMatchLister_disabled() throws IOException {
+ File modelFile1 = new File(rootTestDir, "annotator.en.model");
+ Files.copy(TestDataUtils.getTestAnnotatorModelFile(), modelFile1);
- assertThat(modelFiles).hasSize(0);
+ RegularFilePatternMatchLister regularFilePatternMatchLister =
+ new RegularFilePatternMatchLister(
+ MODEL_TYPE, rootTestDir, "annotator\\.(.*)\\.model", () -> false);
+ ImmutableList<ModelFile> listedModels = regularFilePatternMatchLister.list(MODEL_TYPE);
+
+ assertThat(listedModels).isEmpty();
}
private static void recursiveDelete(File f) {
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
index 1acdebf..7565a0b 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TestDataUtils.java
@@ -16,63 +16,44 @@
package com.android.textclassifier;
-import com.android.textclassifier.ModelFileManager.ModelFile;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import android.content.Context;
+import com.android.textclassifier.ModelFileManager.ModelType;
+import com.android.textclassifier.ModelFileManager.RegularFileFullMatchLister;
import com.google.common.collect.ImmutableList;
-import com.google.common.collect.ImmutableMap;
import java.io.File;
-import java.util.Locale;
-import java.util.function.Supplier;
/** Utils to access test data files. */
public final class TestDataUtils {
- private static final ImmutableMap<String, Supplier<ImmutableList<ModelFile>>>
- MODEL_FILES_SUPPLIER =
- new ImmutableMap.Builder<String, Supplier<ImmutableList<ModelFile>>>()
- .put(
- ModelType.ANNOTATOR,
- () ->
- ImmutableList.of(
- new ModelFile(
- ModelType.ANNOTATOR,
- new File(
- TestDataUtils.getTestDataFolder(), "testdata/annotator.model"),
- 711,
- ImmutableList.of(Locale.ENGLISH),
- "en",
- false)))
- .put(
- ModelType.ACTIONS_SUGGESTIONS,
- () ->
- ImmutableList.of(
- new ModelFile(
- ModelType.ACTIONS_SUGGESTIONS,
- new File(TestDataUtils.getTestDataFolder(), "testdata/actions.model"),
- 104,
- ImmutableList.of(Locale.ENGLISH),
- "en",
- false)))
- .put(
- ModelType.LANG_ID,
- () ->
- ImmutableList.of(
- new ModelFile(
- ModelType.LANG_ID,
- new File(TestDataUtils.getTestDataFolder(), "testdata/langid.model"),
- 1,
- ImmutableList.of(),
- "*",
- true)))
- .build();
+ private static final String TEST_ANNOTATOR_MODEL_PATH = "testdata/annotator.model";
+ private static final String TEST_ACTIONS_MODEL_PATH = "testdata/actions.model";
+ private static final String TEST_LANGID_MODEL_PATH = "testdata/langid.model";
/** Returns the root folder that contains the test data. */
public static File getTestDataFolder() {
return new File("/data/local/tmp/TextClassifierServiceTest/");
}
- public static ModelFileManager createModelFileManagerForTesting() {
+ public static File getTestAnnotatorModelFile() {
+ return new File(getTestDataFolder(), TEST_ANNOTATOR_MODEL_PATH);
+ }
+
+ public static File getTestActionsModelFile() {
+ return new File(getTestDataFolder(), TEST_ACTIONS_MODEL_PATH);
+ }
+
+ public static File getLangIdModelFile() {
+ return new File(getTestDataFolder(), TEST_LANGID_MODEL_PATH);
+ }
+
+ public static ModelFileManager createModelFileManagerForTesting(Context context) {
return new ModelFileManager(
- /* downloadModelDir= */ TestDataUtils.getTestDataFolder(), MODEL_FILES_SUPPLIER);
+ context,
+ ImmutableList.of(
+ new RegularFileFullMatchLister(
+ ModelType.ANNOTATOR, getTestAnnotatorModelFile(), () -> true),
+ new RegularFileFullMatchLister(
+ ModelType.ACTIONS_SUGGESTIONS, getTestActionsModelFile(), () -> true),
+ new RegularFileFullMatchLister(ModelType.LANG_ID, getLangIdModelFile(), () -> true)));
}
private TestDataUtils() {}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
new file mode 100644
index 0000000..0031368
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierApiTest.java
@@ -0,0 +1,197 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import android.app.UiAutomation;
+import android.icu.util.ULocale;
+import android.provider.DeviceConfig;
+import android.view.textclassifier.ConversationAction;
+import android.view.textclassifier.ConversationActions;
+import android.view.textclassifier.TextClassification;
+import android.view.textclassifier.TextClassificationManager;
+import android.view.textclassifier.TextClassifier;
+import android.view.textclassifier.TextLanguage;
+import android.view.textclassifier.TextLinks;
+import android.view.textclassifier.TextLinks.TextLink;
+import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
+import androidx.test.ext.junit.runners.AndroidJUnit4;
+import androidx.test.filters.SmallTest;
+import androidx.test.platform.app.InstrumentationRegistry;
+import com.google.common.collect.ImmutableList;
+import java.util.ArrayList;
+import java.util.List;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExternalResource;
+import org.junit.runner.RunWith;
+
+/**
+ * End-to-end tests for the {@link TextClassifier} APIs. Unlike {@link TextClassifierImplTest}.
+ *
+ * <p>Unlike {@link TextClassifierImplTest}, we are trying to run the tests in a environment that is
+ * closer to the production environment. For example, we are not injecting the model files.
+ */
+@SmallTest
+@RunWith(AndroidJUnit4.class)
+public class TextClassifierApiTest {
+
+ private TextClassifier textClassifier;
+
+ @Rule
+ public final ExtServicesTextClassifierRule extServicesTextClassifierRule =
+ new ExtServicesTextClassifierRule();
+
+ @Before
+ public void setup() {
+ textClassifier = extServicesTextClassifierRule.getTextClassifier();
+ }
+
+ @Test
+ public void suggestSelection() {
+ String text = "Visit http://www.android.com for more information";
+ String selected = "http";
+ String suggested = "http://www.android.com";
+ int startIndex = text.indexOf(selected);
+ int endIndex = startIndex + selected.length();
+ int smartStartIndex = text.indexOf(suggested);
+ int smartEndIndex = smartStartIndex + suggested.length();
+
+ TextSelection.Request request =
+ new TextSelection.Request.Builder(text, startIndex, endIndex).build();
+
+ TextSelection selection = textClassifier.suggestSelection(request);
+ assertThat(selection.getEntityCount()).isGreaterThan(0);
+ assertThat(selection.getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ assertThat(selection.getSelectionStartIndex()).isEqualTo(smartStartIndex);
+ assertThat(selection.getSelectionEndIndex()).isEqualTo(smartEndIndex);
+ }
+
+ @Test
+ public void classifyText() {
+ String text = "Contact me at droid@android.com";
+ String classifiedText = "droid@android.com";
+ int startIndex = text.indexOf(classifiedText);
+ int endIndex = startIndex + classifiedText.length();
+ TextClassification.Request request =
+ new TextClassification.Request.Builder(text, startIndex, endIndex).build();
+
+ TextClassification classification = textClassifier.classifyText(request);
+ assertThat(classification.getEntityCount()).isGreaterThan(0);
+ assertThat(classification.getEntity(0)).isEqualTo(TextClassifier.TYPE_EMAIL);
+ assertThat(classification.getText()).isEqualTo(classifiedText);
+ assertThat(classification.getActions().size()).isGreaterThan(1);
+ }
+
+ @Test
+ public void generateLinks() {
+ String text = "Check this out, http://www.android.com";
+
+ TextLinks.Request request = new TextLinks.Request.Builder(text).build();
+
+ TextLinks textLinks = textClassifier.generateLinks(request);
+
+ List<TextLink> links = new ArrayList<>(textLinks.getLinks());
+ assertThat(textLinks.getText().toString()).isEqualTo(text);
+ assertThat(links).hasSize(1);
+ assertThat(links.get(0).getEntityCount()).isGreaterThan(0);
+ assertThat(links.get(0).getEntity(0)).isEqualTo(TextClassifier.TYPE_URL);
+ assertThat(links.get(0).getConfidenceScore(TextClassifier.TYPE_URL)).isGreaterThan(0);
+ }
+
+ @Test
+ public void detectedLanguage() {
+ String text = "朝、ピカチュウ";
+ TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
+
+ TextLanguage textLanguage = textClassifier.detectLanguage(request);
+
+ assertThat(textLanguage.getLocaleHypothesisCount()).isGreaterThan(0);
+ assertThat(textLanguage.getLocale(0).getLanguage()).isEqualTo("ja");
+ assertThat(textLanguage.getConfidenceScore(ULocale.JAPANESE)).isGreaterThan(0);
+ }
+
+ @Test
+ public void suggestConversationActions() {
+ ConversationActions.Message message =
+ new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
+ .setText("Check this out: https://www.android.com")
+ .build();
+ ConversationActions.Request request =
+ new ConversationActions.Request.Builder(ImmutableList.of(message)).build();
+
+ ConversationActions conversationActions = textClassifier.suggestConversationActions(request);
+
+ assertThat(conversationActions.getConversationActions()).hasSize(1);
+ ConversationAction conversationAction = conversationActions.getConversationActions().get(0);
+ assertThat(conversationAction.getType()).isEqualTo(ConversationAction.TYPE_OPEN_URL);
+ assertThat(conversationAction.getAction()).isNotNull();
+ }
+
+ /** A rule that manages a text classifier that is backed by the ExtServices. */
+ private static class ExtServicesTextClassifierRule extends ExternalResource {
+ private static final String CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE =
+ "textclassifier_service_package_override";
+
+ private String textClassifierServiceOverrideFlagOldValue;
+
+ @Override
+ protected void before() {
+ UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ try {
+ uiAutomation.adoptShellPermissionIdentity();
+ textClassifierServiceOverrideFlagOldValue =
+ DeviceConfig.getString(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ null);
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ "com.google.android.ext.services",
+ /* makeDefault= */ false);
+ } finally {
+ uiAutomation.dropShellPermissionIdentity();
+ }
+ }
+
+ @Override
+ protected void after() {
+ UiAutomation uiAutomation = InstrumentationRegistry.getInstrumentation().getUiAutomation();
+ try {
+ uiAutomation.adoptShellPermissionIdentity();
+ DeviceConfig.setProperty(
+ DeviceConfig.NAMESPACE_TEXTCLASSIFIER,
+ CONFIG_TEXT_CLASSIFIER_SERVICE_PACKAGE_OVERRIDE,
+ textClassifierServiceOverrideFlagOldValue,
+ /* makeDefault= */ false);
+ } finally {
+ uiAutomation.dropShellPermissionIdentity();
+ }
+ }
+
+ public TextClassifier getTextClassifier() {
+ TextClassificationManager textClassificationManager =
+ ApplicationProvider.getApplicationContext()
+ .getSystemService(TextClassificationManager.class);
+ return textClassificationManager.getTextClassifier();
+ }
+ }
+}
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
index 5007e2a..06ec640 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierImplTest.java
@@ -38,10 +38,12 @@
import android.view.textclassifier.TextLanguage;
import android.view.textclassifier.TextLinks;
import android.view.textclassifier.TextSelection;
+import androidx.test.core.app.ApplicationProvider;
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import com.android.textclassifier.testing.FakeContextBuilder;
import com.google.common.collect.ImmutableList;
+import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -63,7 +65,7 @@
private TextClassifierImpl classifier;
private final ModelFileManager modelFileManager =
- TestDataUtils.createModelFileManagerForTesting();
+ TestDataUtils.createModelFileManagerForTesting(ApplicationProvider.getApplicationContext());
@Before
public void setup() {
@@ -77,7 +79,7 @@
}
@Test
- public void testSuggestSelection() {
+ public void testSuggestSelection() throws IOException {
String text = "Contact me at droid@android.com";
String selected = "droid";
String suggested = "droid@android.com";
@@ -96,7 +98,7 @@
}
@Test
- public void testSuggestSelection_url() {
+ public void testSuggestSelection_url() throws IOException {
String text = "Visit http://www.android.com for more information";
String selected = "http";
String suggested = "http://www.android.com";
@@ -114,7 +116,7 @@
}
@Test
- public void testSmartSelection_withEmoji() {
+ public void testSmartSelection_withEmoji() throws IOException {
String text = "\uD83D\uDE02 Hello.";
String selected = "Hello";
int startIndex = text.indexOf(selected);
@@ -129,7 +131,7 @@
}
@Test
- public void testClassifyText() {
+ public void testClassifyText() throws IOException {
String text = "Contact me at droid@android.com";
String classifiedText = "droid@android.com";
int startIndex = text.indexOf(classifiedText);
@@ -144,7 +146,7 @@
}
@Test
- public void testClassifyText_url() {
+ public void testClassifyText_url() throws IOException {
String text = "Visit www.android.com for more information";
String classifiedText = "www.android.com";
int startIndex = text.indexOf(classifiedText);
@@ -160,7 +162,7 @@
}
@Test
- public void testClassifyText_address() {
+ public void testClassifyText_address() throws IOException {
String text = "Brandschenkestrasse 110, Zürich, Switzerland";
TextClassification.Request request =
new TextClassification.Request.Builder(text, 0, text.length())
@@ -172,7 +174,7 @@
}
@Test
- public void testClassifyText_url_inCaps() {
+ public void testClassifyText_url_inCaps() throws IOException {
String text = "Visit HTTP://ANDROID.COM for more information";
String classifiedText = "HTTP://ANDROID.COM";
int startIndex = text.indexOf(classifiedText);
@@ -188,7 +190,7 @@
}
@Test
- public void testClassifyText_date() {
+ public void testClassifyText_date() throws IOException {
String text = "Let's meet on January 9, 2018.";
String classifiedText = "January 9, 2018";
int startIndex = text.indexOf(classifiedText);
@@ -209,7 +211,7 @@
}
@Test
- public void testClassifyText_datetime() {
+ public void testClassifyText_datetime() throws IOException {
String text = "Let's meet 2018/01/01 10:30:20.";
String classifiedText = "2018/01/01 10:30:20";
int startIndex = text.indexOf(classifiedText);
@@ -224,7 +226,7 @@
}
@Test
- public void testClassifyText_foreignText() {
+ public void testClassifyText_foreignText() throws IOException {
LocaleList originalLocales = LocaleList.getDefault();
LocaleList.setDefault(LocaleList.forLanguageTags("en"));
String japaneseText = "これは日本語のテキストです";
@@ -253,7 +255,7 @@
}
@Test
- public void testGenerateLinks_phone() {
+ public void testGenerateLinks_phone() throws IOException {
String text = "The number is +12122537077. See you tonight!";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
assertThat(
@@ -262,7 +264,7 @@
}
@Test
- public void testGenerateLinks_exclude() {
+ public void testGenerateLinks_exclude() throws IOException {
String text = "You want apple@banana.com. See you tonight!";
List<String> hints = ImmutableList.of();
List<String> included = ImmutableList.of();
@@ -278,7 +280,7 @@
}
@Test
- public void testGenerateLinks_explicit_address() {
+ public void testGenerateLinks_explicit_address() throws IOException {
String text = "The address is 1600 Amphitheater Parkway, Mountain View, CA. See you!";
List<String> explicit = Arrays.asList(TextClassifier.TYPE_ADDRESS);
TextLinks.Request request =
@@ -293,7 +295,7 @@
}
@Test
- public void testGenerateLinks_exclude_override() {
+ public void testGenerateLinks_exclude_override() throws IOException {
String text = "You want apple@banana.com. See you tonight!";
List<String> hints = ImmutableList.of();
List<String> included = Arrays.asList(TextClassifier.TYPE_EMAIL);
@@ -309,7 +311,7 @@
}
@Test
- public void testGenerateLinks_maxLength() {
+ public void testGenerateLinks_maxLength() throws IOException {
char[] manySpaces = new char[classifier.getMaxGenerateLinksTextLength()];
Arrays.fill(manySpaces, ' ');
TextLinks.Request request = new TextLinks.Request.Builder(new String(manySpaces)).build();
@@ -318,7 +320,7 @@
}
@Test
- public void testApplyLinks_unsupportedCharacter() {
+ public void testApplyLinks_unsupportedCharacter() throws IOException {
Spannable url = new SpannableString("\u202Emoc.diordna.com");
TextLinks.Request request = new TextLinks.Request.Builder(url).build();
assertEquals(
@@ -335,7 +337,7 @@
}
@Test
- public void testGenerateLinks_entityData() {
+ public void testGenerateLinks_entityData() throws IOException {
String text = "The number is +12122537077.";
Bundle extras = new Bundle();
ExtrasUtils.putIsSerializedEntityDataEnabled(extras, true);
@@ -352,7 +354,7 @@
}
@Test
- public void testGenerateLinks_entityData_disabled() {
+ public void testGenerateLinks_entityData_disabled() throws IOException {
String text = "The number is +12122537077.";
TextLinks.Request request = new TextLinks.Request.Builder(text).build();
@@ -365,7 +367,7 @@
}
@Test
- public void testDetectLanguage() {
+ public void testDetectLanguage() throws IOException {
String text = "This is English text";
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
TextLanguage textLanguage = classifier.detectLanguage(request);
@@ -373,7 +375,7 @@
}
@Test
- public void testDetectLanguage_japanese() {
+ public void testDetectLanguage_japanese() throws IOException {
String text = "これは日本語のテキストです";
TextLanguage.Request request = new TextLanguage.Request.Builder(text).build();
TextLanguage textLanguage = classifier.detectLanguage(request);
@@ -381,7 +383,7 @@
}
@Test
- public void testSuggestConversationActions_textReplyOnly_maxOne() {
+ public void testSuggestConversationActions_textReplyOnly_maxOne() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Where are you?")
@@ -405,7 +407,7 @@
}
@Test
- public void testSuggestConversationActions_textReplyOnly_noMax() {
+ public void testSuggestConversationActions_textReplyOnly_noMax() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Where are you?")
@@ -428,7 +430,7 @@
}
@Test
- public void testSuggestConversationActions_openUrl() {
+ public void testSuggestConversationActions_openUrl() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Check this out: https://www.android.com")
@@ -455,7 +457,7 @@
}
@Test
- public void testSuggestConversationActions_copy() {
+ public void testSuggestConversationActions_copy() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("Authentication code: 12345")
@@ -483,7 +485,7 @@
}
@Test
- public void testSuggestConversationActions_deduplicate() {
+ public void testSuggestConversationActions_deduplicate() throws IOException {
ConversationActions.Message message =
new ConversationActions.Message.Builder(ConversationActions.Message.PERSON_USER_OTHERS)
.setText("a@android.com b@android.com")
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
index c0a823e..e1e7982 100644
--- a/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
+++ b/java/tests/instrumentation/src/com/android/textclassifier/TextClassifierSettingsTest.java
@@ -22,7 +22,7 @@
import androidx.test.ext.junit.runners.AndroidJUnit4;
import androidx.test.filters.SmallTest;
import androidx.test.platform.app.InstrumentationRegistry;
-import com.android.textclassifier.ModelFileManager.ModelFile.ModelType;
+import com.android.textclassifier.ModelFileManager.ModelType;
import java.util.function.Consumer;
import org.junit.After;
import org.junit.Before;
diff --git a/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java b/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java
new file mode 100644
index 0000000..ec1405b
--- /dev/null
+++ b/java/tests/instrumentation/src/com/android/textclassifier/testing/SetDefaultLocalesRule.java
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.textclassifier.testing;
+
+import android.os.LocaleList;
+import org.junit.rules.ExternalResource;
+
+public class SetDefaultLocalesRule extends ExternalResource {
+
+ private LocaleList originalValue;
+
+ @Override
+ protected void before() throws Throwable {
+ super.before();
+ originalValue = LocaleList.getDefault();
+ }
+
+ public void set(LocaleList newValue) {
+ LocaleList.setDefault(newValue);
+ }
+
+ @Override
+ protected void after() {
+ super.after();
+ LocaleList.setDefault(originalValue);
+ }
+}
diff --git a/jni/com/google/android/textclassifier/AnnotatorModel.java b/jni/com/google/android/textclassifier/AnnotatorModel.java
index 3382acd..47a369e 100644
--- a/jni/com/google/android/textclassifier/AnnotatorModel.java
+++ b/jni/com/google/android/textclassifier/AnnotatorModel.java
@@ -342,6 +342,8 @@
@Nullable private final String contactNickname;
@Nullable private final String contactEmailAddress;
@Nullable private final String contactPhoneNumber;
+ @Nullable private final String contactAccountType;
+ @Nullable private final String contactAccountName;
@Nullable private final String contactId;
@Nullable private final String appName;
@Nullable private final String appPackageName;
@@ -363,6 +365,8 @@
@Nullable String contactNickname,
@Nullable String contactEmailAddress,
@Nullable String contactPhoneNumber,
+ @Nullable String contactAccountType,
+ @Nullable String contactAccountName,
@Nullable String contactId,
@Nullable String appName,
@Nullable String appPackageName,
@@ -382,6 +386,8 @@
this.contactNickname = contactNickname;
this.contactEmailAddress = contactEmailAddress;
this.contactPhoneNumber = contactPhoneNumber;
+ this.contactAccountType = contactAccountType;
+ this.contactAccountName = contactAccountName;
this.contactId = contactId;
this.appName = appName;
this.appPackageName = appPackageName;
@@ -444,6 +450,16 @@
}
@Nullable
+ public String getContactAccountType() {
+ return contactAccountType;
+ }
+
+ @Nullable
+ public String getContactAccountName() {
+ return contactAccountName;
+ }
+
+ @Nullable
public String getContactId() {
return contactId;
}
@@ -550,22 +566,40 @@
public InputFragment(String text) {
this.text = text;
this.datetimeOptionsNullable = null;
+ this.boundingBoxTop = 0;
+ this.boundingBoxHeight = 0;
}
- public InputFragment(String text, DatetimeOptions datetimeOptions) {
+ public InputFragment(
+ String text,
+ DatetimeOptions datetimeOptions,
+ float boundingBoxTop,
+ float boundingBoxHeight) {
this.text = text;
this.datetimeOptionsNullable = datetimeOptions;
+ this.boundingBoxTop = boundingBoxTop;
+ this.boundingBoxHeight = boundingBoxHeight;
}
private final String text;
// The DatetimeOptions can't be Optional because the _api16 build of the TCLib SDK does not
// support java.util.Optional.
private final DatetimeOptions datetimeOptionsNullable;
+ private final float boundingBoxTop;
+ private final float boundingBoxHeight;
public String getText() {
return text;
}
+ public float getBoundingBoxTop() {
+ return boundingBoxTop;
+ }
+
+ public float getBoundingBoxHeight() {
+ return boundingBoxHeight;
+ }
+
public boolean hasDatetimeOptions() {
return datetimeOptionsNullable != null;
}
diff --git a/jni/com/google/android/textclassifier/LangIdModel.java b/jni/com/google/android/textclassifier/LangIdModel.java
index 15f352b..890c9b0 100644
--- a/jni/com/google/android/textclassifier/LangIdModel.java
+++ b/jni/com/google/android/textclassifier/LangIdModel.java
@@ -16,6 +16,7 @@
package com.google.android.textclassifier;
+import android.content.res.AssetFileDescriptor;
import java.util.concurrent.atomic.AtomicBoolean;
/**
@@ -48,6 +49,21 @@
}
}
+ /**
+ * Creates a new instance of LangId predictor, using the provided model image, given as an {@link
+ * AssetFileDescriptor}.
+ */
+ public LangIdModel(AssetFileDescriptor assetFileDescriptor) {
+ modelPtr =
+ nativeNewWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ if (modelPtr == 0L) {
+ throw new IllegalArgumentException("Couldn't initialize LangId from asset file descriptor.");
+ }
+ }
+
/** Creates a new instance of LangId predictor, using the provided model image. */
public LangIdModel(int fd, long offset, long size) {
modelPtr = nativeNewWithOffset(fd, offset, size);
@@ -103,14 +119,22 @@
return nativeGetVersion(modelPtr);
}
- public float getLangIdThreshold() {
- return nativeGetLangIdThreshold(modelPtr);
- }
-
public static int getVersion(int fd) {
return nativeGetVersionFromFd(fd);
}
+ /** Returns the version of the model. */
+ public static int getVersion(AssetFileDescriptor assetFileDescriptor) {
+ return nativeGetVersionWithOffset(
+ assetFileDescriptor.getParcelFileDescriptor().getFd(),
+ assetFileDescriptor.getStartOffset(),
+ assetFileDescriptor.getLength());
+ }
+
+ public float getLangIdThreshold() {
+ return nativeGetLangIdThreshold(modelPtr);
+ }
+
/** Retrieves the pointer to the native object. */
long getNativePointer() {
return modelPtr;
@@ -153,4 +177,6 @@
private native float nativeGetLangIdNoiseThreshold(long nativePtr);
private native int nativeGetMinTextSizeInBytes(long nativePtr);
+
+ private static native int nativeGetVersionWithOffset(int fd, long offset, long size);
}
diff --git a/native/Android.bp b/native/Android.bp
index 1881f67..1b49c11 100644
--- a/native/Android.bp
+++ b/native/Android.bp
@@ -215,6 +215,7 @@
defaults: ["libtextclassifier_defaults"],
srcs: [
":libtextclassifier_java_test_sources",
+ "annotator/datetime/testing/*.cc",
"actions/test-utils.cc",
"utils/testing/annotator.cc",
"utils/testing/logging_event_listener.cc",
diff --git a/native/JavaTests.bp b/native/JavaTests.bp
index c95ae09..1c5099d 100644
--- a/native/JavaTests.bp
+++ b/native/JavaTests.bp
@@ -19,8 +19,7 @@
srcs: [
"actions/actions-suggestions_test.cc",
"actions/grammar-actions_test.cc",
- "annotator/datetime/datetime-grounder_test.cc",
- "annotator/datetime/parser_test.cc",
+ "annotator/datetime/regex-parser_test.cc",
"utils/grammar/parsing/lexer_test.cc",
"utils/regex-match_test.cc",
"utils/calendar/calendar_test.cc",
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index b0dda55..a9edde9 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -1345,7 +1345,9 @@
TC3_LOG(ERROR) << "Rejecting too long input: " << message.text.size();
return {};
}
- if (!IsValidUTF8(message.text.data(), message.text.size())) {
+
+ if (!unilib_->IsValidUtf8(UTF8ToUnicodeText(
+ message.text.data(), message.text.size(), /*do_copy=*/false))) {
TC3_LOG(ERROR) << "Not valid utf8 provided.";
return response;
}
diff --git a/native/actions/actions-suggestions_test.cc b/native/actions/actions-suggestions_test.cc
index 1bb029b..55aa852 100644
--- a/native/actions/actions-suggestions_test.cc
+++ b/native/actions/actions-suggestions_test.cc
@@ -97,6 +97,20 @@
EXPECT_THAT(response.actions, IsEmpty());
}
+TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidUtf8) {
+ std::unique_ptr<ActionsSuggestions> actions_suggestions =
+ LoadTestModel(kModelFileName);
+
+ const ActionsSuggestionsResponse response =
+ actions_suggestions->SuggestActions(
+ {{{/*user_id=*/1,
+ "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80",
+ /*reference_time_ms_utc=*/0,
+ /*reference_timezone=*/"Europe/Zurich",
+ /*annotations=*/{}, /*locales=*/"en"}}});
+ EXPECT_THAT(response.actions, IsEmpty());
+}
+
TEST_F(ActionsSuggestionsTest, SuggestsActions) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelFileName);
diff --git a/native/actions/actions_model.fbs b/native/actions/actions_model.fbs
index 763a1ed..4894ed0 100755
--- a/native/actions/actions_model.fbs
+++ b/native/actions/actions_model.fbs
@@ -14,16 +14,16 @@
// limitations under the License.
//
-include "annotator/model.fbs";
-include "utils/flatbuffers/flatbuffers.fbs";
-include "utils/intents/intent-config.fbs";
+include "utils/codepoint-range.fbs";
include "utils/normalization.fbs";
-include "utils/tokenizer.fbs";
+include "utils/intents/intent-config.fbs";
+include "annotator/model.fbs";
include "utils/resources.fbs";
include "utils/zlib/buffer.fbs";
-include "utils/codepoint-range.fbs";
+include "utils/flatbuffers/flatbuffers.fbs";
include "actions/actions-entity-data.fbs";
include "utils/grammar/rules.fbs";
+include "utils/tokenizer.fbs";
file_identifier "TC3A";
diff --git a/native/actions/test_data/actions_suggestions_grammar_test.model b/native/actions/test_data/actions_suggestions_grammar_test.model
index 3cc85dc..daca1ff 100644
--- a/native/actions/test_data/actions_suggestions_grammar_test.model
+++ b/native/actions/test_data/actions_suggestions_grammar_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.model b/native/actions/test_data/actions_suggestions_test.model
index 0b2fda6..488fc1c 100644
--- a/native/actions/test_data/actions_suggestions_test.model
+++ b/native/actions/test_data/actions_suggestions_test.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
index 8a389c7..247d164 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_9heads.model
Binary files differ
diff --git a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
index 926bd22..5906b1a 100644
--- a/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
+++ b/native/actions/test_data/actions_suggestions_test.multi_task_sr_nudge_signal_v0.model
Binary files differ
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 52819af..2635820 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -27,6 +27,7 @@
#include <vector>
#include "annotator/collections.h"
+#include "annotator/datetime/regex-parser.h"
#include "annotator/flatbuffer-utils.h"
#include "annotator/knowledge/knowledge-engine-types.h"
#include "annotator/model_generated.h"
@@ -34,7 +35,9 @@
#include "utils/base/logging.h"
#include "utils/base/status.h"
#include "utils/base/statusor.h"
+#include "utils/calendar/calendar.h"
#include "utils/checksum.h"
+#include "utils/i18n/locale-list.h"
#include "utils/i18n/locale.h"
#include "utils/math/softmax.h"
#include "utils/normalization.h"
@@ -106,12 +109,8 @@
}
// Returns whether the provided input is valid:
-// * Valid utf8 text.
// * Sane span indices.
bool IsValidSpanInput(const UnicodeText& context, const CodepointSpan& span) {
- if (!context.is_valid()) {
- return false;
- }
return (span.first >= 0 && span.first < span.second &&
span.second <= context.size_codepoints());
}
@@ -418,7 +417,7 @@
}
if (model_->datetime_model()) {
- datetime_parser_ = DatetimeParser::Instance(
+ datetime_parser_ = RegexDatetimeParser::Instance(
model_->datetime_model(), unilib_, calendarlib_, decompressor.get());
if (!datetime_parser_) {
TC3_LOG(ERROR) << "Could not initialize datetime parser.";
@@ -849,6 +848,11 @@
const UnicodeText context_unicode = UTF8ToUnicodeText(context,
/*do_copy=*/false);
+ if (!unilib_->IsValidUtf8(context_unicode)) {
+ TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
+ return original_click_indices;
+ }
+
if (!IsValidSpanInput(context_unicode, click_indices)) {
TC3_VLOG(1)
<< "Trying to run SuggestSelection with invalid input, indices: "
@@ -1671,20 +1675,20 @@
UTF8ToUnicodeText(context, /*do_copy=*/false)
.UTF8Substring(selection_indices.first, selection_indices.second);
- std::vector<DatetimeParseResultSpan> datetime_spans;
-
- if (datetime_parser_) {
- if (!datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
- options.reference_timezone, options.locales,
- ModeFlag_CLASSIFICATION,
- options.annotation_usecase,
- /*anchor_start_end=*/true, &datetime_spans)) {
- TC3_LOG(ERROR) << "Error during parsing datetime.";
- return false;
- }
+ LocaleList locale_list = LocaleList::ParseFrom(options.locales);
+ StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
+ datetime_parser_->Parse(selection_text, options.reference_time_ms_utc,
+ options.reference_timezone, locale_list,
+ ModeFlag_CLASSIFICATION,
+ options.annotation_usecase,
+ /*anchor_start_end=*/true);
+ if (!result_status.ok()) {
+ TC3_LOG(ERROR) << "Error during parsing datetime.";
+ return false;
}
- for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
+ for (const DatetimeParseResultSpan& datetime_span :
+ result_status.ValueOrDie()) {
// Only consider the result valid if the selection and extracted datetime
// spans exactly match.
if (CodepointSpan(datetime_span.span.first + selection_indices.first,
@@ -1740,8 +1744,15 @@
return {};
}
- if (!IsValidSpanInput(UTF8ToUnicodeText(context, /*do_copy=*/false),
- selection_indices)) {
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+
+ if (!unilib_->IsValidUtf8(context_unicode)) {
+ TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
+ return {};
+ }
+
+ if (!IsValidSpanInput(context_unicode, selection_indices)) {
TC3_VLOG(1) << "Trying to run ClassifyText with invalid input: "
<< selection_indices.first << " " << selection_indices.second;
return {};
@@ -1815,9 +1826,6 @@
candidates.back().source = AnnotatedSpan::Source::DATETIME;
}
- const UnicodeText context_unicode =
- UTF8ToUnicodeText(context, /*do_copy=*/false);
-
// Try the number annotator.
// TODO(b/126579108): Propagate error status.
ClassificationResult number_annotator_result;
@@ -2105,10 +2113,6 @@
const UnicodeText context_unicode =
UTF8ToUnicodeText(context, /*do_copy=*/false);
- if (!context_unicode.is_valid()) {
- return Status(StatusCode::INVALID_ARGUMENT,
- "Context string isn't valid UTF8.");
- }
std::vector<Locale> detected_text_language_tags;
if (!ParseLocales(options.detected_text_language_tags,
@@ -2328,15 +2332,21 @@
std::vector<std::string> text_to_annotate;
text_to_annotate.reserve(string_fragments.size());
+ std::vector<FragmentMetadata> fragment_metadata;
+ fragment_metadata.reserve(string_fragments.size());
for (const auto& string_fragment : string_fragments) {
text_to_annotate.push_back(string_fragment.text);
+ fragment_metadata.push_back(
+ {.relative_bounding_box_top = string_fragment.bounding_box_top,
+ .relative_bounding_box_height = string_fragment.bounding_box_height});
}
// KnowledgeEngine is special, because it supports annotation of multiple
// fragments at once.
if (knowledge_engine_ &&
!knowledge_engine_
- ->ChunkMultipleSpans(text_to_annotate, options.annotation_usecase,
+ ->ChunkMultipleSpans(text_to_annotate, fragment_metadata,
+ options.annotation_usecase,
options.location_context, options.permissions,
options.annotate_mode, &annotation_candidates)
.ok()) {
@@ -2394,6 +2404,13 @@
return {};
}
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ if (!unilib_->IsValidUtf8(context_unicode)) {
+ TC3_LOG(ERROR) << "Rejecting input, invalid UTF8.";
+ return {};
+ }
+
std::vector<InputFragment> string_fragments;
string_fragments.push_back({.text = context});
StatusOr<Annotations> annotations =
@@ -3066,18 +3083,21 @@
AnnotationUsecase annotation_usecase,
bool is_serialized_entity_data_enabled,
std::vector<AnnotatedSpan>* result) const {
- std::vector<DatetimeParseResultSpan> datetime_spans;
-
- if (datetime_parser_) {
- if (!datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
- reference_timezone, locales, mode,
- annotation_usecase,
- /*anchor_start_end=*/false, &datetime_spans)) {
- return false;
- }
+ if (!datetime_parser_) {
+ return true;
+ }
+ LocaleList locale_list = LocaleList::ParseFrom(locales);
+ StatusOr<std::vector<DatetimeParseResultSpan>> result_status =
+ datetime_parser_->Parse(context_unicode, reference_time_ms_utc,
+ reference_timezone, locale_list, mode,
+ annotation_usecase,
+ /*anchor_start_end=*/false);
+ if (!result_status.ok()) {
+ return false;
}
- for (const DatetimeParseResultSpan& datetime_span : datetime_spans) {
+ for (const DatetimeParseResultSpan& datetime_span :
+ result_status.ValueOrDie()) {
AnnotatedSpan annotated_span;
annotated_span.span = datetime_span.span;
for (const DatetimeParseResult& parse_result : datetime_span.data) {
diff --git a/native/annotator/annotator.h b/native/annotator/annotator.h
index fbf3777..5397f56 100644
--- a/native/annotator/annotator.h
+++ b/native/annotator/annotator.h
@@ -45,6 +45,7 @@
#include "annotator/zlib-utils.h"
#include "utils/base/status.h"
#include "utils/base/statusor.h"
+#include "utils/calendar/calendar.h"
#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/i18n/locale.h"
diff --git a/native/annotator/annotator_jni.cc b/native/annotator/annotator_jni.cc
index 8d5ad33..7f095f9 100644
--- a/native/annotator/annotator_jni.cc
+++ b/native/annotator/annotator_jni.cc
@@ -205,6 +205,22 @@
env, classification_result.contact_phone_number.c_str()));
}
+ ScopedLocalRef<jstring> contact_account_type;
+ if (!classification_result.contact_account_type.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_account_type,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_account_type.c_str()));
+ }
+
+ ScopedLocalRef<jstring> contact_account_name;
+ if (!classification_result.contact_account_name.empty()) {
+ TC3_ASSIGN_OR_RETURN(
+ contact_account_name,
+ JniHelper::NewStringUTF(
+ env, classification_result.contact_account_name.c_str()));
+ }
+
ScopedLocalRef<jstring> contact_id;
if (!classification_result.contact_id.empty()) {
TC3_ASSIGN_OR_RETURN(
@@ -275,7 +291,8 @@
row_datetime_parse.get(), serialized_knowledge_result.get(),
contact_name.get(), contact_given_name.get(), contact_family_name.get(),
contact_nickname.get(), contact_email_address.get(),
- contact_phone_number.get(), contact_id.get(), app_name.get(),
+ contact_phone_number.get(), contact_account_type.get(),
+ contact_account_name.get(), contact_id.get(), app_name.get(),
app_package_name.get(), extras.get(), serialized_entity_data.get(),
remote_action_templates_result.get(), classification_result.duration_ms,
classification_result.numeric_value,
@@ -304,13 +321,23 @@
JniHelper::GetMethodID(
env, result_class.get(), "<init>",
"(Ljava/lang/String;FL" TC3_PACKAGE_PATH TC3_ANNOTATOR_CLASS_NAME_STR
- "$DatetimeResult;[BLjava/lang/String;Ljava/lang/String;Ljava/lang/"
- "String;"
- "Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/"
- "String;"
- "Ljava/lang/String;Ljava/lang/String;[L" TC3_PACKAGE_PATH
- "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";[B[L" TC3_PACKAGE_PATH
- "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";JJD)V"));
+ "$DatetimeResult;"
+ "[B"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "Ljava/lang/String;"
+ "[L" TC3_PACKAGE_PATH "" TC3_NAMED_VARIANT_CLASS_NAME_STR ";"
+ "[B"
+ "[L" TC3_PACKAGE_PATH "" TC3_REMOTE_ACTION_TEMPLATE_CLASS_NAME_STR ";"
+ "JJD)V"));
TC3_ASSIGN_OR_RETURN(const jmethodID datetime_parse_class_constructor,
JniHelper::GetMethodID(env, datetime_parse_class.get(),
"<init>", "(JI)V"));
diff --git a/native/annotator/annotator_jni_common.cc b/native/annotator/annotator_jni_common.cc
index 1d64e67..a6f636f 100644
--- a/native/annotator/annotator_jni_common.cc
+++ b/native/annotator/annotator_jni_common.cc
@@ -427,6 +427,24 @@
.reference_timezone = reference_timezone};
}
+ // .getBoundingBoxHeight()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_bounding_box_height,
+ JniHelper::GetMethodID(env, fragment_class.get(),
+ "getBoundingBoxHeight", "()F"));
+ TC3_ASSIGN_OR_RETURN(
+ float bounding_box_height,
+ JniHelper::CallFloatMethod(env, jfragment, get_bounding_box_height));
+
+ fragment.bounding_box_height = bounding_box_height;
+
+ // .getBoundingBoxTop()
+ TC3_ASSIGN_OR_RETURN(jmethodID get_bounding_box_top,
+ JniHelper::GetMethodID(env, fragment_class.get(),
+ "getBoundingBoxTop", "()F"));
+ TC3_ASSIGN_OR_RETURN(
+ float bounding_box_top,
+ JniHelper::CallFloatMethod(env, jfragment, get_bounding_box_top));
+ fragment.bounding_box_top = bounding_box_top;
return fragment;
}
} // namespace libtextclassifier3
diff --git a/native/annotator/annotator_test-include.cc b/native/annotator/annotator_test-include.cc
index da77acd..3ed91e1 100644
--- a/native/annotator/annotator_test-include.cc
+++ b/native/annotator/annotator_test-include.cc
@@ -2955,5 +2955,23 @@
/*duration_ms=*/3 * 60 * 60 * 1000)));
}
+TEST_F(AnnotatorTest, DoesntProcessInvalidUtf8) {
+ const std::string test_model = ReadFile(GetTestModelPath());
+ const std::string invalid_utf8_text_with_phone_number =
+ "(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80";
+
+ std::unique_ptr<Annotator> classifier =
+ Annotator::FromString(test_model, unilib_.get(), calendarlib_.get());
+ ASSERT_TRUE(classifier);
+ EXPECT_THAT(classifier->Annotate(invalid_utf8_text_with_phone_number),
+ IsEmpty());
+ EXPECT_THAT(
+ classifier->SuggestSelection(invalid_utf8_text_with_phone_number, {1, 4}),
+ Eq(CodepointSpan{1, 4}));
+ EXPECT_THAT(
+ classifier->ClassifyText(invalid_utf8_text_with_phone_number, {0, 14}),
+ IsEmpty());
+}
+
} // namespace test_internal
} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/datetime-grounder.cc b/native/annotator/datetime/datetime-grounder.cc
deleted file mode 100644
index c6d2a66..0000000
--- a/native/annotator/datetime/datetime-grounder.cc
+++ /dev/null
@@ -1,213 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/datetime/datetime-grounder.h"
-
-#include <vector>
-
-#include "annotator/datetime/datetime_generated.h"
-#include "annotator/datetime/utils.h"
-#include "annotator/types.h"
-#include "utils/base/integral_types.h"
-#include "utils/base/status.h"
-#include "utils/base/status_macros.h"
-
-using ::libtextclassifier3::grammar::datetime::AbsoluteDateTime;
-using ::libtextclassifier3::grammar::datetime::ComponentType;
-using ::libtextclassifier3::grammar::datetime::Meridiem;
-using ::libtextclassifier3::grammar::datetime::RelativeDateTime;
-using ::libtextclassifier3::grammar::datetime::RelativeDatetimeComponent;
-using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
-using ::libtextclassifier3::grammar::datetime::RelativeDatetimeComponent_::
- Modifier;
-
-namespace libtextclassifier3 {
-
-namespace {
-
-StatusOr<DatetimeComponent::RelativeQualifier> ToRelativeQualifier(
- const Modifier& modifier) {
- switch (modifier) {
- case Modifier::Modifier_THIS:
- return DatetimeComponent::RelativeQualifier::THIS;
- case Modifier::Modifier_LAST:
- return DatetimeComponent::RelativeQualifier::LAST;
- case Modifier::Modifier_NEXT:
- return DatetimeComponent::RelativeQualifier::NEXT;
- case Modifier::Modifier_NOW:
- return DatetimeComponent::RelativeQualifier::NOW;
- case Modifier::Modifier_TOMORROW:
- return DatetimeComponent::RelativeQualifier::TOMORROW;
- case Modifier::Modifier_YESTERDAY:
- return DatetimeComponent::RelativeQualifier::YESTERDAY;
- case Modifier::Modifier_UNSPECIFIED:
- return DatetimeComponent::RelativeQualifier::UNSPECIFIED;
- default:
- return Status(StatusCode::INTERNAL,
- "Couldn't parse the Modifier to RelativeQualifier.");
- }
-}
-
-StatusOr<DatetimeComponent::ComponentType> ToComponentType(
- const grammar::datetime::ComponentType component_type) {
- switch (component_type) {
- case grammar::datetime::ComponentType_YEAR:
- return DatetimeComponent::ComponentType::YEAR;
- case grammar::datetime::ComponentType_MONTH:
- return DatetimeComponent::ComponentType::MONTH;
- case grammar::datetime::ComponentType_WEEK:
- return DatetimeComponent::ComponentType::WEEK;
- case grammar::datetime::ComponentType_DAY_OF_WEEK:
- return DatetimeComponent::ComponentType::DAY_OF_WEEK;
- case grammar::datetime::ComponentType_DAY_OF_MONTH:
- return DatetimeComponent::ComponentType::DAY_OF_MONTH;
- case grammar::datetime::ComponentType_HOUR:
- return DatetimeComponent::ComponentType::HOUR;
- case grammar::datetime::ComponentType_MINUTE:
- return DatetimeComponent::ComponentType::MINUTE;
- case grammar::datetime::ComponentType_SECOND:
- return DatetimeComponent::ComponentType::SECOND;
- case grammar::datetime::ComponentType_MERIDIEM:
- return DatetimeComponent::ComponentType::MERIDIEM;
- case grammar::datetime::ComponentType_UNSPECIFIED:
- return DatetimeComponent::ComponentType::UNSPECIFIED;
- default:
- return Status(StatusCode::INTERNAL,
- "Couldn't parse the DatetimeComponent's ComponentType from "
- "grammar's datetime ComponentType.");
- }
-}
-
-void FillAbsoluteDateTimeComponents(
- const grammar::datetime::AbsoluteDateTime* absolute_datetime,
- DatetimeParsedData* datetime_parsed_data) {
- if (absolute_datetime->year() >= 0) {
- datetime_parsed_data->SetAbsoluteValue(
- DatetimeComponent::ComponentType::YEAR, absolute_datetime->year());
- }
- if (absolute_datetime->month() >= 0) {
- datetime_parsed_data->SetAbsoluteValue(
- DatetimeComponent::ComponentType::MONTH, absolute_datetime->month());
- }
- if (absolute_datetime->day() >= 0) {
- datetime_parsed_data->SetAbsoluteValue(
- DatetimeComponent::ComponentType::DAY_OF_MONTH,
- absolute_datetime->day());
- }
- if (absolute_datetime->week_day() >= 0) {
- datetime_parsed_data->SetAbsoluteValue(
- DatetimeComponent::ComponentType::DAY_OF_WEEK,
- absolute_datetime->week_day());
- }
- if (absolute_datetime->hour() >= 0) {
- datetime_parsed_data->SetAbsoluteValue(
- DatetimeComponent::ComponentType::HOUR, absolute_datetime->hour());
- }
- if (absolute_datetime->minute() >= 0) {
- datetime_parsed_data->SetAbsoluteValue(
- DatetimeComponent::ComponentType::MINUTE, absolute_datetime->minute());
- }
- if (absolute_datetime->second() >= 0) {
- datetime_parsed_data->SetAbsoluteValue(
- DatetimeComponent::ComponentType::SECOND, absolute_datetime->second());
- }
- if (absolute_datetime->meridiem() != grammar::datetime::Meridiem_UNKNOWN) {
- datetime_parsed_data->SetAbsoluteValue(
- DatetimeComponent::ComponentType::MERIDIEM,
- absolute_datetime->meridiem() == grammar::datetime::Meridiem_AM ? 0
- : 1);
- }
- if (absolute_datetime->time_zone()) {
- datetime_parsed_data->SetAbsoluteValue(
- DatetimeComponent::ComponentType::ZONE_OFFSET,
- absolute_datetime->time_zone()->utc_offset_mins());
- }
-}
-
-StatusOr<DatetimeParsedData> FillRelativeDateTimeComponents(
- const grammar::datetime::RelativeDateTime* relative_datetime) {
- DatetimeParsedData datetime_parsed_data;
- for (const RelativeDatetimeComponent* relative_component :
- *relative_datetime->relative_datetime_component()) {
- TC3_ASSIGN_OR_RETURN(const DatetimeComponent::ComponentType component_type,
- ToComponentType(relative_component->component_type()));
- datetime_parsed_data.SetRelativeCount(component_type,
- relative_component->value());
- TC3_ASSIGN_OR_RETURN(
- const DatetimeComponent::RelativeQualifier relative_qualifier,
- ToRelativeQualifier(relative_component->modifier()));
- datetime_parsed_data.SetRelativeValue(component_type, relative_qualifier);
- }
- if (relative_datetime->base()) {
- FillAbsoluteDateTimeComponents(relative_datetime->base(),
- &datetime_parsed_data);
- }
- return datetime_parsed_data;
-}
-
-} // namespace
-
-DatetimeGrounder::DatetimeGrounder(const CalendarLib* calendarlib)
- : calendarlib_(*calendarlib) {}
-
-StatusOr<std::vector<DatetimeParseResult>> DatetimeGrounder::Ground(
- const int64 reference_time_ms_utc, const std::string& reference_timezone,
- const std::string& reference_locale,
- const grammar::datetime::UngroundedDatetime* ungrounded_datetime) const {
- DatetimeParsedData datetime_parsed_data;
- if (ungrounded_datetime->absolute_datetime()) {
- FillAbsoluteDateTimeComponents(ungrounded_datetime->absolute_datetime(),
- &datetime_parsed_data);
- } else if (ungrounded_datetime->relative_datetime()) {
- TC3_ASSIGN_OR_RETURN(datetime_parsed_data,
- FillRelativeDateTimeComponents(
- ungrounded_datetime->relative_datetime()));
- }
- std::vector<DatetimeParsedData> interpretations;
- FillInterpretations(datetime_parsed_data,
- calendarlib_.GetGranularity(datetime_parsed_data),
- &interpretations);
- std::vector<DatetimeParseResult> datetime_parse_result;
-
- for (const DatetimeParsedData& interpretation : interpretations) {
- std::vector<DatetimeComponent> date_components;
- interpretation.GetDatetimeComponents(&date_components);
- DatetimeParseResult result;
- // Text classifier only provides ambiguity limited to “AM/PM” which is
- // encoded in the pair of DatetimeParseResult; both corresponding to the
- // same date, but one corresponding to “AM” and the other one corresponding
- // to “PM”.
- if (!calendarlib_.InterpretParseData(
- interpretation, reference_time_ms_utc, reference_timezone,
- reference_locale, /*prefer_future_for_unspecified_date=*/true,
- &(result.time_ms_utc), &(result.granularity))) {
- return Status(
- StatusCode::INTERNAL,
- "Couldn't parse the UngroundedDatetime to DatetimeParseResult.");
- }
-
- // Sort the date time units by component type.
- std::sort(date_components.begin(), date_components.end(),
- [](DatetimeComponent a, DatetimeComponent b) {
- return a.component_type > b.component_type;
- });
- result.datetime_components.swap(date_components);
- datetime_parse_result.push_back(result);
- }
- return datetime_parse_result;
-}
-
-} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/datetime-grounder.h b/native/annotator/datetime/datetime-grounder.h
deleted file mode 100644
index 223d679..0000000
--- a/native/annotator/datetime/datetime-grounder.h
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_DATETIME_GROUNDER_H_
-#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_DATETIME_GROUNDER_H_
-
-#include <vector>
-
-#include "annotator/datetime/datetime_generated.h"
-#include "annotator/types.h"
-#include "utils/base/statusor.h"
-#include "utils/calendar/calendar.h"
-
-namespace libtextclassifier3 {
-
-// Utility class to resolve and complete an ungrounded datetime specification.
-class DatetimeGrounder {
- public:
- explicit DatetimeGrounder(const CalendarLib* calendarlib);
-
- // Resolves ambiguities and produces concrete datetime results from an
- // ungrounded datetime specification.
- StatusOr<std::vector<DatetimeParseResult>> Ground(
- const int64 reference_time_ms_utc, const std::string& reference_timezone,
- const std::string& reference_locale,
- const grammar::datetime::UngroundedDatetime* ungrounded_datetime) const;
-
- private:
- const CalendarLib& calendarlib_;
-};
-
-} // namespace libtextclassifier3
-
-#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_DATETIME_GROUNDER_H_
diff --git a/native/annotator/datetime/datetime-grounder_test.cc b/native/annotator/datetime/datetime-grounder_test.cc
deleted file mode 100644
index f53bcc6..0000000
--- a/native/annotator/datetime/datetime-grounder_test.cc
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "annotator/datetime/datetime-grounder.h"
-
-#include "annotator/datetime/datetime_generated.h"
-#include "utils/flatbuffers/flatbuffers.h"
-#include "utils/jvm-test-utils.h"
-#include "gmock/gmock.h"
-#include "gtest/gtest.h"
-
-using ::libtextclassifier3::grammar::datetime::AbsoluteDateTimeT;
-using ::libtextclassifier3::grammar::datetime::ComponentType;
-using ::libtextclassifier3::grammar::datetime::Meridiem;
-using ::libtextclassifier3::grammar::datetime::RelativeDatetimeComponentT;
-using ::libtextclassifier3::grammar::datetime::RelativeDateTimeT;
-using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
-using ::libtextclassifier3::grammar::datetime::UngroundedDatetimeT;
-using ::libtextclassifier3::grammar::datetime::RelativeDatetimeComponent_::
- Modifier;
-using ::testing::SizeIs;
-
-namespace libtextclassifier3 {
-
-class DatetimeGrounderTest : public testing::Test {
- public:
- void SetUp() override {
- calendarlib_ = CreateCalendarLibForTesting();
- datetime_grounder_.reset(new DatetimeGrounder(calendarlib_.get()));
- }
-
- protected:
- OwnedFlatbuffer<UngroundedDatetime, std::string> BuildAbsoluteDatetime(
- const int year, const int month, const int day, const int hour,
- const int minute, const int second, const Meridiem meridiem) {
- grammar::datetime::UngroundedDatetimeT ungrounded_datetime;
- ungrounded_datetime.absolute_datetime.reset(new AbsoluteDateTimeT);
-
- // Set absolute datetime value.
- ungrounded_datetime.absolute_datetime->year = year;
- ungrounded_datetime.absolute_datetime->month = month;
- ungrounded_datetime.absolute_datetime->day = day;
- ungrounded_datetime.absolute_datetime->hour = hour;
- ungrounded_datetime.absolute_datetime->minute = minute;
- ungrounded_datetime.absolute_datetime->second = second;
- ungrounded_datetime.absolute_datetime->meridiem = meridiem;
-
- return OwnedFlatbuffer<UngroundedDatetime, std::string>(
- PackFlatbuffer<UngroundedDatetime>(&ungrounded_datetime));
- }
-
- OwnedFlatbuffer<UngroundedDatetime, std::string> BuildRelativeDatetime(
- const ComponentType component_type, const Modifier modifier,
- const int relative_count) {
- UngroundedDatetimeT ungrounded_datetime;
- ungrounded_datetime.relative_datetime.reset(new RelativeDateTimeT);
- ungrounded_datetime.relative_datetime->relative_datetime_component
- .emplace_back(new RelativeDatetimeComponentT);
- ungrounded_datetime.relative_datetime->relative_datetime_component.back()
- ->modifier = modifier;
- ungrounded_datetime.relative_datetime->relative_datetime_component.back()
- ->component_type = component_type;
- ungrounded_datetime.relative_datetime->relative_datetime_component.back()
- ->value = relative_count;
- ungrounded_datetime.relative_datetime->base.reset(new AbsoluteDateTimeT);
- ungrounded_datetime.relative_datetime->base->year = 2020;
- ungrounded_datetime.relative_datetime->base->month = 6;
- ungrounded_datetime.relative_datetime->base->day = 30;
-
- return OwnedFlatbuffer<UngroundedDatetime, std::string>(
- PackFlatbuffer<UngroundedDatetime>(&ungrounded_datetime));
- }
-
- std::unique_ptr<DatetimeGrounder> datetime_grounder_;
- std::unique_ptr<CalendarLib> calendarlib_;
-};
-
-TEST_F(DatetimeGrounderTest, AbsoluteDatetimeTest) {
- const OwnedFlatbuffer<UngroundedDatetime, std::string> datetime =
- BuildAbsoluteDatetime(/*year=*/2000, /*month=*/03, /*day=*/30,
- /*hour=*/11, /*minute=*/59, /*second=*/59,
- grammar::datetime::Meridiem_AM);
- const std::vector<DatetimeParseResult> data =
- datetime_grounder_
- ->Ground(
- /*reference_time_ms_utc=*/0, "Europe/Zurich", "en-US",
- datetime.get())
- .ValueOrDie();
-
- EXPECT_THAT(data, SizeIs(1));
- EXPECT_EQ(data[0].granularity, DatetimeGranularity::GRANULARITY_SECOND);
-
- // Meridiem
- EXPECT_EQ(data[0].datetime_components[0].component_type,
- DatetimeComponent::ComponentType::MERIDIEM);
- EXPECT_EQ(data[0].datetime_components[0].value, 0);
-
- EXPECT_EQ(data[0].datetime_components[1].component_type,
- DatetimeComponent::ComponentType::SECOND);
- EXPECT_EQ(data[0].datetime_components[1].component_type,
- DatetimeComponent::ComponentType::SECOND);
-
- EXPECT_EQ(data[0].datetime_components[2].component_type,
- DatetimeComponent::ComponentType::MINUTE);
- EXPECT_EQ(data[0].datetime_components[2].value, 59);
-
- EXPECT_EQ(data[0].datetime_components[3].component_type,
- DatetimeComponent::ComponentType::HOUR);
- EXPECT_EQ(data[0].datetime_components[3].value, 11);
-
- EXPECT_EQ(data[0].datetime_components[4].component_type,
- DatetimeComponent::ComponentType::DAY_OF_MONTH);
- EXPECT_EQ(data[0].datetime_components[4].value, 30);
-
- EXPECT_EQ(data[0].datetime_components[5].component_type,
- DatetimeComponent::ComponentType::MONTH);
- EXPECT_EQ(data[0].datetime_components[5].value, 3);
-
- EXPECT_EQ(data[0].datetime_components[6].component_type,
- DatetimeComponent::ComponentType::YEAR);
- EXPECT_EQ(data[0].datetime_components[6].value, 2000);
-}
-
-TEST_F(DatetimeGrounderTest, InterpretDatetimeTest) {
- const OwnedFlatbuffer<UngroundedDatetime, std::string> datetime =
- BuildAbsoluteDatetime(/*year=*/2000, /*month=*/03, /*day=*/30,
- /*hour=*/11, /*minute=*/59, /*second=*/59,
- grammar::datetime::Meridiem_UNKNOWN);
- const std::vector<DatetimeParseResult> data =
- datetime_grounder_
- ->Ground(
- /*reference_time_ms_utc=*/0, "Europe/Zurich", "en-US",
- datetime.get())
- .ValueOrDie();
-
- EXPECT_THAT(data, SizeIs(2));
- EXPECT_EQ(data[0].granularity, DatetimeGranularity::GRANULARITY_SECOND);
- EXPECT_EQ(data[1].granularity, DatetimeGranularity::GRANULARITY_SECOND);
-
- // Check Meridiem's values
- EXPECT_EQ(data[0].datetime_components[0].component_type,
- DatetimeComponent::ComponentType::MERIDIEM);
- EXPECT_EQ(data[0].datetime_components[0].value, 0);
- EXPECT_EQ(data[1].datetime_components[0].component_type,
- DatetimeComponent::ComponentType::MERIDIEM);
- EXPECT_EQ(data[1].datetime_components[0].value, 1);
-}
-
-TEST_F(DatetimeGrounderTest, RelativeDatetimeTest) {
- const OwnedFlatbuffer<UngroundedDatetime, std::string> datetime =
- BuildRelativeDatetime(ComponentType::ComponentType_DAY_OF_MONTH,
- Modifier::Modifier_NEXT, 1);
- const std::vector<DatetimeParseResult> data =
- datetime_grounder_
- ->Ground(
- /*reference_time_ms_utc=*/0, "Europe/Zurich", "en-US",
- datetime.get())
- .ValueOrDie();
-
- EXPECT_THAT(data, SizeIs(1));
- EXPECT_EQ(data[0].granularity, DatetimeGranularity::GRANULARITY_DAY);
-
- EXPECT_EQ(data[0].datetime_components[0].component_type,
- DatetimeComponent::ComponentType::DAY_OF_MONTH);
- EXPECT_EQ(data[0].datetime_components[0].relative_qualifier,
- DatetimeComponent::RelativeQualifier::NEXT);
- EXPECT_EQ(data[0].datetime_components[0].relative_count, 1);
- EXPECT_EQ(data[0].datetime_components[1].component_type,
- DatetimeComponent::ComponentType::MONTH);
- EXPECT_EQ(data[0].datetime_components[2].component_type,
- DatetimeComponent::ComponentType::YEAR);
-}
-
-} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/datetime.fbs b/native/annotator/datetime/datetime.fbs
index 77cbc25..8012cdc 100755
--- a/native/annotator/datetime/datetime.fbs
+++ b/native/annotator/datetime/datetime.fbs
@@ -62,7 +62,7 @@
namespace libtextclassifier3.grammar.datetime;
table TimeZone {
- // Offset from UTC/GTM in hours.
+ // Offset from UTC/GTM in minutes.
utc_offset_mins:int;
}
diff --git a/native/annotator/datetime/parser.h b/native/annotator/datetime/parser.h
index 8b58388..3b3e578 100644
--- a/native/annotator/datetime/parser.h
+++ b/native/annotator/datetime/parser.h
@@ -19,18 +19,13 @@
#include <memory>
#include <string>
-#include <unordered_map>
-#include <unordered_set>
#include <vector>
-#include "annotator/datetime/extractor.h"
-#include "annotator/model_generated.h"
#include "annotator/types.h"
-#include "utils/base/integral_types.h"
-#include "utils/calendar/calendar.h"
+#include "utils/base/statusor.h"
+#include "utils/i18n/locale-list.h"
+#include "utils/i18n/locale.h"
#include "utils/utf8/unicodetext.h"
-#include "utils/utf8/unilib.h"
-#include "utils/zlib/zlib.h"
namespace libtextclassifier3 {
@@ -38,87 +33,25 @@
// time.
class DatetimeParser {
public:
- static std::unique_ptr<DatetimeParser> Instance(
- const DatetimeModel* model, const UniLib* unilib,
- const CalendarLib* calendarlib, ZlibDecompressor* decompressor);
+ virtual ~DatetimeParser() = default;
// Parses the dates in 'input' and fills result. Makes sure that the results
// do not overlap.
// If 'anchor_start_end' is true the extracted results need to start at the
// beginning of 'input' and end at the end of it.
- bool Parse(const std::string& input, int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const;
+ virtual StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const std::string& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const = 0;
// Same as above but takes UnicodeText.
- bool Parse(const UnicodeText& input, int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const;
-
- protected:
- explicit DatetimeParser(const DatetimeModel* model, const UniLib* unilib,
- const CalendarLib* calendarlib,
- ZlibDecompressor* decompressor);
-
- // Returns a list of locale ids for given locale spec string (comma-separated
- // locale names). Assigns the first parsed locale to reference_locale.
- std::vector<int> ParseAndExpandLocales(const std::string& locales,
- std::string* reference_locale) const;
-
- // Helper function that finds datetime spans, only using the rules associated
- // with the given locales.
- bool FindSpansUsingLocales(
- const std::vector<int>& locale_ids, const UnicodeText& input,
- const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ virtual StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const UnicodeText& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
ModeFlag mode, AnnotationUsecase annotation_usecase,
- bool anchor_start_end, const std::string& reference_locale,
- std::unordered_set<int>* executed_rules,
- std::vector<DatetimeParseResultSpan>* found_spans) const;
-
- bool ParseWithRule(const CompiledRule& rule, const UnicodeText& input,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, const int locale_id,
- bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* result) const;
-
- // Converts the current match in 'matcher' into DatetimeParseResult.
- bool ExtractDatetime(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResult>* results,
- CodepointSpan* result_span) const;
-
- // Parse and extract information from current match in 'matcher'.
- bool HandleParseMatch(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResultSpan>* result) const;
-
- private:
- bool initialized_;
- const UniLib& unilib_;
- const CalendarLib& calendarlib_;
- std::vector<CompiledRule> rules_;
- std::unordered_map<int, std::vector<int>> locale_to_rules_;
- std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
- std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
- type_and_locale_to_extractor_rule_;
- std::unordered_map<std::string, int> locale_string_to_id_;
- std::vector<int> default_locale_ids_;
- bool use_extractors_for_locating_;
- bool generate_alternative_interpretations_when_ambiguous_;
- bool prefer_future_for_unspecified_date_;
+ bool anchor_start_end) const = 0;
};
-
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_PARSER_H_
diff --git a/native/annotator/datetime/parser.cc b/native/annotator/datetime/regex-parser.cc
similarity index 69%
rename from native/annotator/datetime/parser.cc
rename to native/annotator/datetime/regex-parser.cc
index 72fd3ab..4dc9c56 100644
--- a/native/annotator/datetime/parser.cc
+++ b/native/annotator/datetime/regex-parser.cc
@@ -14,33 +14,36 @@
* limitations under the License.
*/
-#include "annotator/datetime/parser.h"
+#include "annotator/datetime/regex-parser.h"
+#include <iterator>
#include <set>
#include <unordered_set>
#include "annotator/datetime/extractor.h"
#include "annotator/datetime/utils.h"
+#include "utils/base/statusor.h"
#include "utils/calendar/calendar.h"
#include "utils/i18n/locale.h"
#include "utils/strings/split.h"
#include "utils/zlib/zlib_regex.h"
namespace libtextclassifier3 {
-std::unique_ptr<DatetimeParser> DatetimeParser::Instance(
+std::unique_ptr<DatetimeParser> RegexDatetimeParser::Instance(
const DatetimeModel* model, const UniLib* unilib,
const CalendarLib* calendarlib, ZlibDecompressor* decompressor) {
- std::unique_ptr<DatetimeParser> result(
- new DatetimeParser(model, unilib, calendarlib, decompressor));
+ std::unique_ptr<RegexDatetimeParser> result(
+ new RegexDatetimeParser(model, unilib, calendarlib, decompressor));
if (!result->initialized_) {
result.reset();
}
return result;
}
-DatetimeParser::DatetimeParser(const DatetimeModel* model, const UniLib* unilib,
- const CalendarLib* calendarlib,
- ZlibDecompressor* decompressor)
+RegexDatetimeParser::RegexDatetimeParser(const DatetimeModel* model,
+ const UniLib* unilib,
+ const CalendarLib* calendarlib,
+ ZlibDecompressor* decompressor)
: unilib_(*unilib), calendarlib_(*calendarlib) {
initialized_ = false;
@@ -113,23 +116,24 @@
initialized_ = true;
}
-bool DatetimeParser::Parse(
+StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const {
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const {
return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
- reference_time_ms_utc, reference_timezone, locales, mode,
- annotation_usecase, anchor_start_end, results);
+ reference_time_ms_utc, reference_timezone, locale_list, mode,
+ annotation_usecase, anchor_start_end);
}
-bool DatetimeParser::FindSpansUsingLocales(
+StatusOr<std::vector<DatetimeParseResultSpan>>
+RegexDatetimeParser::FindSpansUsingLocales(
const std::vector<int>& locale_ids, const UnicodeText& input,
const int64 reference_time_ms_utc, const std::string& reference_timezone,
ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
const std::string& reference_locale,
- std::unordered_set<int>* executed_rules,
- std::vector<DatetimeParseResultSpan>* found_spans) const {
+ std::unordered_set<int>* executed_rules) const {
+ std::vector<DatetimeParseResultSpan> found_spans;
for (const int locale_id : locale_ids) {
auto rules_it = locale_to_rules_.find(locale_id);
if (rules_it == locale_to_rules_.end()) {
@@ -152,34 +156,33 @@
}
executed_rules->insert(rule_id);
-
- if (!ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- anchor_start_end, found_spans)) {
- return false;
- }
+ TC3_ASSIGN_OR_RETURN(
+ const std::vector<DatetimeParseResultSpan>& found_spans_per_rule,
+ ParseWithRule(rules_[rule_id], input, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id,
+ anchor_start_end));
+ found_spans.insert(std::end(found_spans),
+ std::begin(found_spans_per_rule),
+ std::end(found_spans_per_rule));
}
}
- return true;
+ return found_spans;
}
-bool DatetimeParser::Parse(
+StatusOr<std::vector<DatetimeParseResultSpan>> RegexDatetimeParser::Parse(
const UnicodeText& input, const int64 reference_time_ms_utc,
- const std::string& reference_timezone, const std::string& locales,
- ModeFlag mode, AnnotationUsecase annotation_usecase, bool anchor_start_end,
- std::vector<DatetimeParseResultSpan>* results) const {
- std::vector<DatetimeParseResultSpan> found_spans;
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const {
std::unordered_set<int> executed_rules;
- std::string reference_locale;
const std::vector<int> requested_locales =
- ParseAndExpandLocales(locales, &reference_locale);
- if (!FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
- reference_timezone, mode, annotation_usecase,
- anchor_start_end, reference_locale,
- &executed_rules, &found_spans)) {
- return false;
- }
-
+ ParseAndExpandLocales(locale_list.GetLocaleTags());
+ TC3_ASSIGN_OR_RETURN(
+ const std::vector<DatetimeParseResultSpan>& found_spans,
+ FindSpansUsingLocales(requested_locales, input, reference_time_ms_utc,
+ reference_timezone, mode, annotation_usecase,
+ anchor_start_end, locale_list.GetReferenceLocale(),
+ &executed_rules));
std::vector<std::pair<DatetimeParseResultSpan, int>> indexed_found_spans;
indexed_found_spans.reserve(found_spans.size());
for (int i = 0; i < found_spans.size(); i++) {
@@ -200,39 +203,46 @@
}
});
- found_spans.clear();
+ std::vector<DatetimeParseResultSpan> results;
+ std::vector<DatetimeParseResultSpan> resolved_found_spans;
+ resolved_found_spans.reserve(indexed_found_spans.size());
for (auto& span_index_pair : indexed_found_spans) {
- found_spans.push_back(span_index_pair.first);
+ resolved_found_spans.push_back(span_index_pair.first);
}
std::set<int, std::function<bool(int, int)>> chosen_indices_set(
- [&found_spans](int a, int b) {
- return found_spans[a].span.first < found_spans[b].span.first;
+ [&resolved_found_spans](int a, int b) {
+ return resolved_found_spans[a].span.first <
+ resolved_found_spans[b].span.first;
});
- for (int i = 0; i < found_spans.size(); ++i) {
- if (!DoesCandidateConflict(i, found_spans, chosen_indices_set)) {
+ for (int i = 0; i < resolved_found_spans.size(); ++i) {
+ if (!DoesCandidateConflict(i, resolved_found_spans, chosen_indices_set)) {
chosen_indices_set.insert(i);
- results->push_back(found_spans[i]);
+ results.push_back(resolved_found_spans[i]);
}
}
-
- return true;
+ return results;
}
-bool DatetimeParser::HandleParseMatch(
- const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
- int64 reference_time_ms_utc, const std::string& reference_timezone,
- const std::string& reference_locale, int locale_id,
- std::vector<DatetimeParseResultSpan>* result) const {
+StatusOr<std::vector<DatetimeParseResultSpan>>
+RegexDatetimeParser::HandleParseMatch(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale,
+ int locale_id) const {
+ std::vector<DatetimeParseResultSpan> results;
int status = UniLib::RegexMatcher::kNoError;
const int start = matcher.Start(&status);
if (status != UniLib::RegexMatcher::kNoError) {
- return false;
+ return Status(StatusCode::INTERNAL,
+ "Failed to gets the start offset of the last match.");
}
const int end = matcher.End(&status);
if (status != UniLib::RegexMatcher::kNoError) {
- return false;
+ return Status(StatusCode::INTERNAL,
+ "Failed to gets the end offset of the last match.");
}
DatetimeParseResultSpan parse_result;
@@ -240,7 +250,7 @@
if (!ExtractDatetime(rule, matcher, reference_time_ms_utc, reference_timezone,
reference_locale, locale_id, &alternatives,
&parse_result.span)) {
- return false;
+ return Status(StatusCode::INTERNAL, "Failed to extract Datetime.");
}
if (!use_extractors_for_locating_) {
@@ -257,49 +267,44 @@
parse_result.data.push_back(alternative);
}
}
- result->push_back(parse_result);
- return true;
+ results.push_back(parse_result);
+ return results;
}
-bool DatetimeParser::ParseWithRule(
- const CompiledRule& rule, const UnicodeText& input,
- const int64 reference_time_ms_utc, const std::string& reference_timezone,
- const std::string& reference_locale, const int locale_id,
- bool anchor_start_end, std::vector<DatetimeParseResultSpan>* result) const {
+StatusOr<std::vector<DatetimeParseResultSpan>>
+RegexDatetimeParser::ParseWithRule(const CompiledRule& rule,
+ const UnicodeText& input,
+ const int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale,
+ const int locale_id,
+ bool anchor_start_end) const {
+ std::vector<DatetimeParseResultSpan> results;
std::unique_ptr<UniLib::RegexMatcher> matcher =
rule.compiled_regex->Matcher(input);
int status = UniLib::RegexMatcher::kNoError;
if (anchor_start_end) {
if (matcher->Matches(&status) && status == UniLib::RegexMatcher::kNoError) {
- if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- result)) {
- return false;
- }
+ return HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id);
}
} else {
while (matcher->Find(&status) && status == UniLib::RegexMatcher::kNoError) {
- if (!HandleParseMatch(rule, *matcher, reference_time_ms_utc,
- reference_timezone, reference_locale, locale_id,
- result)) {
- return false;
- }
+ TC3_ASSIGN_OR_RETURN(
+ const std::vector<DatetimeParseResultSpan>& pattern_occurrence,
+ HandleParseMatch(rule, *matcher, reference_time_ms_utc,
+ reference_timezone, reference_locale, locale_id));
+ results.insert(std::end(results), std::begin(pattern_occurrence),
+ std::end(pattern_occurrence));
}
}
- return true;
+ return results;
}
-std::vector<int> DatetimeParser::ParseAndExpandLocales(
- const std::string& locales, std::string* reference_locale) const {
- std::vector<StringPiece> split_locales = strings::Split(locales, ',');
- if (!split_locales.empty()) {
- *reference_locale = split_locales[0].ToString();
- } else {
- *reference_locale = "";
- }
-
+std::vector<int> RegexDatetimeParser::ParseAndExpandLocales(
+ const std::vector<StringPiece>& locales) const {
std::vector<int> result;
- for (const StringPiece& locale_str : split_locales) {
+ for (const StringPiece& locale_str : locales) {
auto locale_it = locale_string_to_id_.find(locale_str.ToString());
if (locale_it != locale_string_to_id_.end()) {
result.push_back(locale_it->second);
@@ -348,14 +353,12 @@
return result;
}
-bool DatetimeParser::ExtractDatetime(const CompiledRule& rule,
- const UniLib::RegexMatcher& matcher,
- const int64 reference_time_ms_utc,
- const std::string& reference_timezone,
- const std::string& reference_locale,
- int locale_id,
- std::vector<DatetimeParseResult>* results,
- CodepointSpan* result_span) const {
+bool RegexDatetimeParser::ExtractDatetime(
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResult>* results,
+ CodepointSpan* result_span) const {
DatetimeParsedData parse;
DatetimeExtractor extractor(rule, matcher, locale_id, &unilib_,
extractor_rules_,
diff --git a/native/annotator/datetime/regex-parser.h b/native/annotator/datetime/regex-parser.h
new file mode 100644
index 0000000..e820c21
--- /dev/null
+++ b/native/annotator/datetime/regex-parser.h
@@ -0,0 +1,123 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "annotator/datetime/extractor.h"
+#include "annotator/datetime/parser.h"
+#include "annotator/model_generated.h"
+#include "annotator/types.h"
+#include "utils/base/integral_types.h"
+#include "utils/base/statusor.h"
+#include "utils/calendar/calendar.h"
+#include "utils/strings/stringpiece.h"
+#include "utils/utf8/unicodetext.h"
+#include "utils/utf8/unilib.h"
+#include "utils/zlib/zlib.h"
+
+namespace libtextclassifier3 {
+
+// Parses datetime expressions in the input and resolves them to actual absolute
+// time.
+class RegexDatetimeParser : public DatetimeParser {
+ public:
+ static std::unique_ptr<DatetimeParser> Instance(
+ const DatetimeModel* model, const UniLib* unilib,
+ const CalendarLib* calendarlib, ZlibDecompressor* decompressor);
+
+ // Parses the dates in 'input' and fills result. Makes sure that the results
+ // do not overlap.
+ // If 'anchor_start_end' is true the extracted results need to start at the
+ // beginning of 'input' and end at the end of it.
+ StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const std::string& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const override;
+
+ // Same as above but takes UnicodeText.
+ StatusOr<std::vector<DatetimeParseResultSpan>> Parse(
+ const UnicodeText& input, int64 reference_time_ms_utc,
+ const std::string& reference_timezone, const LocaleList& locale_list,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end) const override;
+
+ protected:
+ explicit RegexDatetimeParser(const DatetimeModel* model, const UniLib* unilib,
+ const CalendarLib* calendarlib,
+ ZlibDecompressor* decompressor);
+
+ // Returns a list of locale ids for given locale spec string (collection of
+ // locale names).
+ std::vector<int> ParseAndExpandLocales(
+ const std::vector<StringPiece>& locales) const;
+
+ // Helper function that finds datetime spans, only using the rules associated
+ // with the given locales.
+ StatusOr<std::vector<DatetimeParseResultSpan>> FindSpansUsingLocales(
+ const std::vector<int>& locale_ids, const UnicodeText& input,
+ const int64 reference_time_ms_utc, const std::string& reference_timezone,
+ ModeFlag mode, AnnotationUsecase annotation_usecase,
+ bool anchor_start_end, const std::string& reference_locale,
+ std::unordered_set<int>* executed_rules) const;
+
+ StatusOr<std::vector<DatetimeParseResultSpan>> ParseWithRule(
+ const CompiledRule& rule, const UnicodeText& input,
+ int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, const int locale_id,
+ bool anchor_start_end) const;
+
+ // Converts the current match in 'matcher' into DatetimeParseResult.
+ bool ExtractDatetime(const CompiledRule& rule,
+ const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc,
+ const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id,
+ std::vector<DatetimeParseResult>* results,
+ CodepointSpan* result_span) const;
+
+ // Parse and extract information from current match in 'matcher'.
+ StatusOr<std::vector<DatetimeParseResultSpan>> HandleParseMatch(
+ const CompiledRule& rule, const UniLib::RegexMatcher& matcher,
+ int64 reference_time_ms_utc, const std::string& reference_timezone,
+ const std::string& reference_locale, int locale_id) const;
+
+ private:
+ bool initialized_;
+ const UniLib& unilib_;
+ const CalendarLib& calendarlib_;
+ std::vector<CompiledRule> rules_;
+ std::unordered_map<int, std::vector<int>> locale_to_rules_;
+ std::vector<std::unique_ptr<const UniLib::RegexPattern>> extractor_rules_;
+ std::unordered_map<DatetimeExtractorType, std::unordered_map<int, int>>
+ type_and_locale_to_extractor_rule_;
+ std::unordered_map<std::string, int> locale_string_to_id_;
+ std::vector<int> default_locale_ids_;
+ bool use_extractors_for_locating_;
+ bool generate_alternative_interpretations_when_ambiguous_;
+ bool prefer_future_for_unspecified_date_;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_REGEX_PARSER_H_
diff --git a/native/annotator/datetime/parser_test.cc b/native/annotator/datetime/regex-parser_test.cc
similarity index 88%
rename from native/annotator/datetime/parser_test.cc
rename to native/annotator/datetime/regex-parser_test.cc
index 3c6c858..a0d9adf 100644
--- a/native/annotator/datetime/parser_test.cc
+++ b/native/annotator/datetime/regex-parser_test.cc
@@ -14,7 +14,7 @@
* limitations under the License.
*/
-#include "annotator/datetime/parser.h"
+#include "annotator/datetime/regex-parser.h"
#include <time.h>
@@ -24,8 +24,11 @@
#include <string>
#include "annotator/annotator.h"
+#include "annotator/datetime/testing/base-parser-test.h"
+#include "annotator/datetime/testing/datetime-component-builder.h"
#include "annotator/model_generated.h"
#include "annotator/types-test-util.h"
+#include "utils/i18n/locale-list.h"
#include "utils/jvm-test-utils.h"
#include "utils/test-data-test-utils.h"
#include "utils/testing/annotator.h"
@@ -33,48 +36,9 @@
#include "gtest/gtest.h"
using std::vector;
-using testing::ElementsAreArray;
namespace libtextclassifier3 {
namespace {
-// Builder class to construct the DatetimeComponents and make the test readable.
-class DatetimeComponentsBuilder {
- public:
- DatetimeComponentsBuilder Add(DatetimeComponent::ComponentType type,
- int value) {
- DatetimeComponent component;
- component.component_type = type;
- component.value = value;
- return AddComponent(component);
- }
-
- DatetimeComponentsBuilder Add(
- DatetimeComponent::ComponentType type, int value,
- DatetimeComponent::RelativeQualifier relative_qualifier,
- int relative_count) {
- DatetimeComponent component;
- component.component_type = type;
- component.value = value;
- component.relative_qualifier = relative_qualifier;
- component.relative_count = relative_count;
- return AddComponent(component);
- }
-
- std::vector<DatetimeComponent> Build() {
- std::vector<DatetimeComponent> result(datetime_components_);
- datetime_components_.clear();
- return result;
- }
-
- private:
- DatetimeComponentsBuilder AddComponent(
- const DatetimeComponent& datetime_component) {
- datetime_components_.push_back(datetime_component);
- return *this;
- }
- std::vector<DatetimeComponent> datetime_components_;
-};
-
std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
std::string ReadFile(const std::string& file_name) {
@@ -82,7 +46,7 @@
return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
-class DateTimeParserTest : public testing::Test {
+class RegexDatetimeParserTest : public DateTimeParserTest {
public:
void SetUp() override {
// Loads default unmodified model. Individual tests can call LoadModel to
@@ -104,139 +68,12 @@
TC3_CHECK(parser_);
}
- bool HasNoResult(const std::string& text, bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- std::vector<DatetimeParseResultSpan> results;
- if (!parser_->Parse(text, 0, timezone, /*locales=*/"", ModeFlag_ANNOTATION,
- annotation_usecase, anchor_start_end, &results)) {
- TC3_LOG(ERROR) << text;
- TC3_CHECK(false);
- }
- return results.empty();
+ // Exposes the date time parser for tests and evaluations.
+ const DatetimeParser* DatetimeParserForTests() const override {
+ return classifier_->DatetimeParserForTests();
}
- bool ParsesCorrectly(const std::string& marked_text,
- const vector<int64>& expected_ms_utcs,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components,
- bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- const std::string& locales = "en-US",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- const UnicodeText marked_text_unicode =
- UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
- auto brace_open_it =
- std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
- auto brace_end_it =
- std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
- TC3_CHECK(brace_open_it != marked_text_unicode.end());
- TC3_CHECK(brace_end_it != marked_text_unicode.end());
-
- std::string text;
- text +=
- UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
- text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
- text += UnicodeText::UTF8Substring(std::next(brace_end_it),
- marked_text_unicode.end());
-
- std::vector<DatetimeParseResultSpan> results;
-
- if (!parser_->Parse(text, 0, timezone, locales, ModeFlag_ANNOTATION,
- annotation_usecase, anchor_start_end, &results)) {
- TC3_LOG(ERROR) << text;
- TC3_CHECK(false);
- }
- if (results.empty()) {
- TC3_LOG(ERROR) << "No results.";
- return false;
- }
-
- const int expected_start_index =
- std::distance(marked_text_unicode.begin(), brace_open_it);
- // The -1 below is to account for the opening bracket character.
- const int expected_end_index =
- std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
-
- std::vector<DatetimeParseResultSpan> filtered_results;
- for (const DatetimeParseResultSpan& result : results) {
- if (SpansOverlap(result.span,
- {expected_start_index, expected_end_index})) {
- filtered_results.push_back(result);
- }
- }
- std::vector<DatetimeParseResultSpan> expected{
- {{expected_start_index, expected_end_index},
- {},
- /*target_classification_score=*/1.0,
- /*priority_score=*/1.0}};
- expected[0].data.resize(expected_ms_utcs.size());
- for (int i = 0; i < expected_ms_utcs.size(); i++) {
- expected[0].data[i] = {expected_ms_utcs[i], expected_granularity,
- datetime_components[i]};
- }
-
- const bool matches =
- testing::Matches(ElementsAreArray(expected))(filtered_results);
- if (!matches) {
- TC3_LOG(ERROR) << "Expected: " << expected[0];
- if (filtered_results.empty()) {
- TC3_LOG(ERROR) << "But got no results.";
- }
- TC3_LOG(ERROR) << "Actual: " << filtered_results[0];
- }
-
- return matches;
- }
-
- bool ParsesCorrectly(const std::string& marked_text,
- const int64 expected_ms_utc,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components,
- bool anchor_start_end = false,
- const std::string& timezone = "Europe/Zurich",
- const std::string& locales = "en-US",
- AnnotationUsecase annotation_usecase =
- AnnotationUsecase_ANNOTATION_USECASE_SMART) {
- return ParsesCorrectly(marked_text, vector<int64>{expected_ms_utc},
- expected_granularity, datetime_components,
- anchor_start_end, timezone, locales,
- annotation_usecase);
- }
-
- bool ParsesCorrectlyGerman(
- const std::string& marked_text, const vector<int64>& expected_ms_utcs,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components) {
- return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
- datetime_components,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"de");
- }
-
- bool ParsesCorrectlyGerman(
- const std::string& marked_text, const int64 expected_ms_utc,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components) {
- return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
- datetime_components,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"de");
- }
-
- bool ParsesCorrectlyChinese(
- const std::string& marked_text, const int64 expected_ms_utc,
- DatetimeGranularity expected_granularity,
- vector<vector<DatetimeComponent>> datetime_components) {
- return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
- datetime_components,
- /*anchor_start_end=*/false,
- /*timezone=*/"Europe/Zurich", /*locales=*/"zh");
- }
-
- protected:
+ private:
std::string model_buffer_;
std::unique_ptr<Annotator> classifier_;
const DatetimeParser* parser_;
@@ -245,7 +82,7 @@
};
// Test with just a few cases to make debugging of general failures easier.
-TEST_F(DateTimeParserTest, ParseShort) {
+TEST_F(RegexDatetimeParserTest, ParseShort) {
EXPECT_TRUE(ParsesCorrectly(
"{January 1, 1988}", 567990000000, GRANULARITY_DAY,
{DatetimeComponentsBuilder()
@@ -255,7 +92,7 @@
.Build()}));
}
-TEST_F(DateTimeParserTest, Parse) {
+TEST_F(RegexDatetimeParserTest, Parse) {
EXPECT_TRUE(ParsesCorrectly(
"{January 1, 1988}", 567990000000, GRANULARITY_DAY,
{DatetimeComponentsBuilder()
@@ -696,7 +533,7 @@
/*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
}
-TEST_F(DateTimeParserTest, ParseWithAnchor) {
+TEST_F(RegexDatetimeParserTest, ParseWithAnchor) {
EXPECT_TRUE(ParsesCorrectly(
"{January 1, 1988}", 567990000000, GRANULARITY_DAY,
{DatetimeComponentsBuilder()
@@ -725,7 +562,7 @@
/*anchor_start_end=*/true));
}
-TEST_F(DateTimeParserTest, ParseWithRawUsecase) {
+TEST_F(RegexDatetimeParserTest, ParseWithRawUsecase) {
// Annotated for RAW usecase.
EXPECT_TRUE(ParsesCorrectly(
"{tomorrow}", 82800000, GRANULARITY_DAY,
@@ -784,7 +621,7 @@
}
// For details please see b/155437137
-TEST_F(DateTimeParserTest, PastRelativeDatetime) {
+TEST_F(RegexDatetimeParserTest, PastRelativeDatetime) {
EXPECT_TRUE(ParsesCorrectly(
"called you {last Saturday}",
-432000000 /* Fri 1969-12-26 16:00:00 PST */, GRANULARITY_DAY,
@@ -830,7 +667,7 @@
/*annotation_usecase=*/AnnotationUsecase_ANNOTATION_USECASE_RAW));
}
-TEST_F(DateTimeParserTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
+TEST_F(RegexDatetimeParserTest, AddsADayWhenTimeInThePastAndDayNotSpecified) {
// ParsesCorrectly uses 0 as the reference time, which corresponds to:
// "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
// it is in the past, and so the parser should move this to the next day ->
@@ -845,7 +682,7 @@
.Build()}));
}
-TEST_F(DateTimeParserTest,
+TEST_F(RegexDatetimeParserTest,
DoesNotAddADayWhenTimeInThePastAndDayNotSpecifiedDisabled) {
// ParsesCorrectly uses 0 as the reference time, which corresponds to:
// "Thu Jan 01 1970 01:00:00" Zurich time. So if we pass "0:30" here, it means
@@ -868,7 +705,7 @@
.Build()}));
}
-TEST_F(DateTimeParserTest, ParsesNoonAndMidnightCorrectly) {
+TEST_F(RegexDatetimeParserTest, ParsesNoonAndMidnightCorrectly) {
EXPECT_TRUE(ParsesCorrectly(
"{January 1, 1988 12:30am}", 567991800000, GRANULARITY_MINUTE,
{DatetimeComponentsBuilder()
@@ -900,7 +737,7 @@
.Build()}));
}
-TEST_F(DateTimeParserTest, ParseGerman) {
+TEST_F(RegexDatetimeParserTest, ParseGerman) {
EXPECT_TRUE(ParsesCorrectlyGerman(
"{Januar 1 2018}", 1514761200000, GRANULARITY_DAY,
{DatetimeComponentsBuilder()
@@ -1310,7 +1147,7 @@
.Build()}));
}
-TEST_F(DateTimeParserTest, ParseChinese) {
+TEST_F(RegexDatetimeParserTest, ParseChinese) {
EXPECT_TRUE(ParsesCorrectlyChinese(
"{明天 7 上午}", 108000000, GRANULARITY_HOUR,
{DatetimeComponentsBuilder()
@@ -1321,7 +1158,7 @@
.Build()}));
}
-TEST_F(DateTimeParserTest, ParseNonUs) {
+TEST_F(RegexDatetimeParserTest, ParseNonUs) {
auto first_may_2015 =
DatetimeComponentsBuilder()
.Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 1)
@@ -1340,7 +1177,7 @@
/*timezone=*/"Europe/Zurich", /*locales=*/"en"));
}
-TEST_F(DateTimeParserTest, ParseUs) {
+TEST_F(RegexDatetimeParserTest, ParseUs) {
auto five_january_2015 =
DatetimeComponentsBuilder()
.Add(DatetimeComponent::ComponentType::DAY_OF_MONTH, 5)
@@ -1360,7 +1197,7 @@
/*locales=*/"es-US"));
}
-TEST_F(DateTimeParserTest, ParseUnknownLanguage) {
+TEST_F(RegexDatetimeParserTest, ParseUnknownLanguage) {
EXPECT_TRUE(ParsesCorrectly(
"bylo to {31. 12. 2015} v 6 hodin", 1451516400000, GRANULARITY_DAY,
{DatetimeComponentsBuilder()
@@ -1372,7 +1209,7 @@
/*timezone=*/"Europe/Zurich", /*locales=*/"xx"));
}
-TEST_F(DateTimeParserTest, WhenAlternativesEnabledGeneratesAlternatives) {
+TEST_F(RegexDatetimeParserTest, WhenAlternativesEnabledGeneratesAlternatives) {
LoadModel([](ModelT* model) {
model->datetime_model->generate_alternative_interpretations_when_ambiguous =
true;
@@ -1423,7 +1260,7 @@
.Build()}));
}
-TEST_F(DateTimeParserTest,
+TEST_F(RegexDatetimeParserTest,
WhenAlternativesDisabledDoesNotGenerateAlternatives) {
LoadModel([](ModelT* model) {
model->datetime_model->generate_alternative_interpretations_when_ambiguous =
@@ -1492,19 +1329,19 @@
unilib_ = CreateUniLibForTesting();
calendarlib_ = CreateCalendarLibForTesting();
parser_ =
- DatetimeParser::Instance(model_fb, unilib_.get(), calendarlib_.get(),
- /*decompressor=*/nullptr);
+ RegexDatetimeParser::Instance(model_fb, unilib_.get(), calendarlib_.get(),
+ /*decompressor=*/nullptr);
ASSERT_TRUE(parser_);
}
bool ParserLocaleTest::HasResult(const std::string& input,
const std::string& locales) {
- std::vector<DatetimeParseResultSpan> results;
- EXPECT_TRUE(parser_->Parse(
+ StatusOr<std::vector<DatetimeParseResultSpan>> results = parser_->Parse(
input, /*reference_time_ms_utc=*/0,
- /*reference_timezone=*/"", locales, ModeFlag_ANNOTATION,
- AnnotationUsecase_ANNOTATION_USECASE_SMART, false, &results));
- return results.size() == 1;
+ /*reference_timezone=*/"", LocaleList::ParseFrom(locales),
+ ModeFlag_ANNOTATION, AnnotationUsecase_ANNOTATION_USECASE_SMART, false);
+ EXPECT_TRUE(results.ok());
+ return results.ValueOrDie().size() == 1;
}
TEST_F(ParserLocaleTest, English) {
diff --git a/native/annotator/datetime/testing/base-parser-test.cc b/native/annotator/datetime/testing/base-parser-test.cc
new file mode 100644
index 0000000..d8dd723
--- /dev/null
+++ b/native/annotator/datetime/testing/base-parser-test.cc
@@ -0,0 +1,162 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/testing/base-parser-test.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "utils/i18n/locale-list.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using std::vector;
+using testing::ElementsAreArray;
+
+namespace libtextclassifier3 {
+
+bool DateTimeParserTest::HasNoResult(const std::string& text,
+ bool anchor_start_end,
+ const std::string& timezone,
+ AnnotationUsecase annotation_usecase) {
+ StatusOr<std::vector<DatetimeParseResultSpan>> results_status =
+ DatetimeParserForTests()->Parse(
+ text, 0, timezone, LocaleList::ParseFrom(/*locale_tags=*/""),
+ ModeFlag_ANNOTATION, annotation_usecase, anchor_start_end);
+ if (!results_status.ok()) {
+ TC3_LOG(ERROR) << text;
+ TC3_CHECK(false);
+ }
+ return results_status.ValueOrDie().empty();
+}
+
+bool DateTimeParserTest::ParsesCorrectly(
+ const std::string& marked_text, const vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end, const std::string& timezone,
+ const std::string& locales, AnnotationUsecase annotation_usecase) {
+ const UnicodeText marked_text_unicode =
+ UTF8ToUnicodeText(marked_text, /*do_copy=*/false);
+ auto brace_open_it =
+ std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '{');
+ auto brace_end_it =
+ std::find(marked_text_unicode.begin(), marked_text_unicode.end(), '}');
+ TC3_CHECK(brace_open_it != marked_text_unicode.end());
+ TC3_CHECK(brace_end_it != marked_text_unicode.end());
+
+ std::string text;
+ text +=
+ UnicodeText::UTF8Substring(marked_text_unicode.begin(), brace_open_it);
+ text += UnicodeText::UTF8Substring(std::next(brace_open_it), brace_end_it);
+ text += UnicodeText::UTF8Substring(std::next(brace_end_it),
+ marked_text_unicode.end());
+
+ StatusOr<std::vector<DatetimeParseResultSpan>> results_status =
+ DatetimeParserForTests()->Parse(
+ text, 0, timezone, LocaleList::ParseFrom(locales),
+ ModeFlag_ANNOTATION, annotation_usecase, anchor_start_end);
+ if (!results_status.ok()) {
+ TC3_LOG(ERROR) << text;
+ TC3_CHECK(false);
+ }
+ // const std::vector<DatetimeParseResultSpan>& results =
+ // results_status.ValueOrDie();
+ if (results_status.ValueOrDie().empty()) {
+ TC3_LOG(ERROR) << "No results.";
+ return false;
+ }
+
+ const int expected_start_index =
+ std::distance(marked_text_unicode.begin(), brace_open_it);
+ // The -1 below is to account for the opening bracket character.
+ const int expected_end_index =
+ std::distance(marked_text_unicode.begin(), brace_end_it) - 1;
+
+ std::vector<DatetimeParseResultSpan> filtered_results;
+ for (const DatetimeParseResultSpan& result : results_status.ValueOrDie()) {
+ if (SpansOverlap(result.span, {expected_start_index, expected_end_index})) {
+ filtered_results.push_back(result);
+ }
+ }
+ std::vector<DatetimeParseResultSpan> expected{
+ {{expected_start_index, expected_end_index},
+ {},
+ /*target_classification_score=*/1.0,
+ /*priority_score=*/1.0}};
+ expected[0].data.resize(expected_ms_utcs.size());
+ for (int i = 0; i < expected_ms_utcs.size(); i++) {
+ expected[0].data[i] = {expected_ms_utcs[i], expected_granularity,
+ datetime_components[i]};
+ }
+
+ const bool matches =
+ testing::Matches(ElementsAreArray(expected))(filtered_results);
+ if (!matches) {
+ TC3_LOG(ERROR) << "Expected: " << expected[0];
+ if (filtered_results.empty()) {
+ TC3_LOG(ERROR) << "But got no results.";
+ }
+ TC3_LOG(ERROR) << "Actual: " << filtered_results[0];
+ }
+
+ return matches;
+}
+
+bool DateTimeParserTest::ParsesCorrectly(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end, const std::string& timezone,
+ const std::string& locales, AnnotationUsecase annotation_usecase) {
+ return ParsesCorrectly(marked_text, vector<int64>{expected_ms_utc},
+ expected_granularity, datetime_components,
+ anchor_start_end, timezone, locales,
+ annotation_usecase);
+}
+
+bool DateTimeParserTest::ParsesCorrectlyGerman(
+ const std::string& marked_text, const vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components) {
+ return ParsesCorrectly(marked_text, expected_ms_utcs, expected_granularity,
+ datetime_components,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"de");
+}
+
+bool DateTimeParserTest::ParsesCorrectlyGerman(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components) {
+ return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
+ datetime_components,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"de");
+}
+
+bool DateTimeParserTest::ParsesCorrectlyChinese(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ vector<vector<DatetimeComponent>> datetime_components) {
+ return ParsesCorrectly(marked_text, expected_ms_utc, expected_granularity,
+ datetime_components,
+ /*anchor_start_end=*/false,
+ /*timezone=*/"Europe/Zurich", /*locales=*/"zh");
+}
+
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/testing/base-parser-test.h b/native/annotator/datetime/testing/base-parser-test.h
new file mode 100644
index 0000000..3465a04
--- /dev/null
+++ b/native/annotator/datetime/testing/base-parser-test.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_BASE_PARSER_TEST_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_BASE_PARSER_TEST_H_
+
+#include <string>
+#include <vector>
+
+#include "annotator/datetime/parser.h"
+#include "annotator/datetime/testing/base-parser-test.h"
+#include "annotator/types.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier3 {
+
+class DateTimeParserTest : public testing::Test {
+ public:
+ bool HasNoResult(const std::string& text, bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART);
+
+ bool ParsesCorrectly(
+ const std::string& marked_text,
+ const std::vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich",
+ const std::string& locales = "en-US",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART);
+
+ bool ParsesCorrectly(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components,
+ bool anchor_start_end = false,
+ const std::string& timezone = "Europe/Zurich",
+ const std::string& locales = "en-US",
+ AnnotationUsecase annotation_usecase =
+ AnnotationUsecase_ANNOTATION_USECASE_SMART);
+
+ bool ParsesCorrectlyGerman(
+ const std::string& marked_text,
+ const std::vector<int64>& expected_ms_utcs,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components);
+
+ bool ParsesCorrectlyGerman(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components);
+
+ bool ParsesCorrectlyChinese(
+ const std::string& marked_text, const int64 expected_ms_utc,
+ DatetimeGranularity expected_granularity,
+ std::vector<std::vector<DatetimeComponent>> datetime_components);
+
+ // Exposes the date time parser for tests and evaluations.
+ virtual const DatetimeParser* DatetimeParserForTests() const = 0;
+};
+
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_BASE_PARSER_TEST_H_
diff --git a/native/annotator/datetime/testing/datetime-component-builder.cc b/native/annotator/datetime/testing/datetime-component-builder.cc
new file mode 100644
index 0000000..f0764da
--- /dev/null
+++ b/native/annotator/datetime/testing/datetime-component-builder.cc
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "annotator/datetime/testing/datetime-component-builder.h"
+
+namespace libtextclassifier3 {
+
+DatetimeComponentsBuilder DatetimeComponentsBuilder::Add(
+ DatetimeComponent::ComponentType type, int value) {
+ DatetimeComponent component;
+ component.component_type = type;
+ component.value = value;
+ return AddComponent(component);
+}
+
+DatetimeComponentsBuilder DatetimeComponentsBuilder::Add(
+ DatetimeComponent::ComponentType type, int value,
+ DatetimeComponent::RelativeQualifier relative_qualifier,
+ int relative_count) {
+ DatetimeComponent component;
+ component.component_type = type;
+ component.value = value;
+ component.relative_qualifier = relative_qualifier;
+ component.relative_count = relative_count;
+ return AddComponent(component);
+}
+
+std::vector<DatetimeComponent> DatetimeComponentsBuilder::Build() {
+ return std::move(datetime_components_);
+}
+
+DatetimeComponentsBuilder DatetimeComponentsBuilder::AddComponent(
+ const DatetimeComponent& datetime_component) {
+ datetime_components_.push_back(datetime_component);
+ return *this;
+}
+} // namespace libtextclassifier3
diff --git a/native/annotator/datetime/testing/datetime-component-builder.h b/native/annotator/datetime/testing/datetime-component-builder.h
new file mode 100644
index 0000000..a6a9f36
--- /dev/null
+++ b/native/annotator/datetime/testing/datetime-component-builder.h
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_DATETIME_COMPONENT_BUILDER_H_
+#define LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_DATETIME_COMPONENT_BUILDER_H_
+
+#include "annotator/types.h"
+
+namespace libtextclassifier3 {
+
+// Builder class to construct the DatetimeComponents and make the test readable.
+class DatetimeComponentsBuilder {
+ public:
+ DatetimeComponentsBuilder Add(DatetimeComponent::ComponentType type,
+ int value);
+
+ DatetimeComponentsBuilder Add(
+ DatetimeComponent::ComponentType type, int value,
+ DatetimeComponent::RelativeQualifier relative_qualifier,
+ int relative_count);
+
+ std::vector<DatetimeComponent> Build();
+
+ private:
+ DatetimeComponentsBuilder AddComponent(
+ const DatetimeComponent& datetime_component);
+ std::vector<DatetimeComponent> datetime_components_;
+};
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_ANNOTATOR_DATETIME_TESTING_DATETIME_COMPONENT_BUILDER_H_
diff --git a/native/annotator/knowledge/knowledge-engine-dummy.h b/native/annotator/knowledge/knowledge-engine-dummy.h
index 2a53288..615ad06 100644
--- a/native/annotator/knowledge/knowledge-engine-dummy.h
+++ b/native/annotator/knowledge/knowledge-engine-dummy.h
@@ -52,12 +52,13 @@
return true;
}
- Status ChunkMultipleSpans(const std::vector<std::string>& text_fragments,
- AnnotationUsecase annotation_usecase,
- const Optional<LocationContext>& location_context,
- const Permissions& permissions,
- const AnnotateMode annotate_mode,
- Annotations* results) const {
+ Status ChunkMultipleSpans(
+ const std::vector<std::string>& text_fragments,
+ const std::vector<FragmentMetadata>& fragment_metadata,
+ AnnotationUsecase annotation_usecase,
+ const Optional<LocationContext>& location_context,
+ const Permissions& permissions, const AnnotateMode annotate_mode,
+ Annotations* results) const {
return Status::OK;
}
diff --git a/native/annotator/knowledge/knowledge-engine-types.h b/native/annotator/knowledge/knowledge-engine-types.h
index 9508c7b..04b71cb 100644
--- a/native/annotator/knowledge/knowledge-engine-types.h
+++ b/native/annotator/knowledge/knowledge-engine-types.h
@@ -21,6 +21,11 @@
enum AnnotateMode { kEntityAnnotation, kEntityAndTopicalityAnnotation };
+struct FragmentMetadata {
+ float relative_bounding_box_top;
+ float relative_bounding_box_height;
+};
+
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_KNOWLEDGE_KNOWLEDGE_ENGINE_TYPES_H_
diff --git a/native/annotator/model.fbs b/native/annotator/model.fbs
index 3197e58..263b122 100755
--- a/native/annotator/model.fbs
+++ b/native/annotator/model.fbs
@@ -14,17 +14,17 @@
// limitations under the License.
//
-include "utils/codepoint-range.fbs";
+include "utils/zlib/buffer.fbs";
+include "utils/flatbuffers/flatbuffers.fbs";
+include "utils/grammar/rules.fbs";
+include "utils/container/bit-vector.fbs";
+include "utils/tokenizer.fbs";
include "utils/normalization.fbs";
-include "annotator/entity-data.fbs";
+include "utils/codepoint-range.fbs";
include "annotator/experimental/experimental.fbs";
include "utils/intents/intent-config.fbs";
-include "utils/tokenizer.fbs";
include "utils/resources.fbs";
-include "utils/grammar/rules.fbs";
-include "utils/zlib/buffer.fbs";
-include "utils/container/bit-vector.fbs";
-include "utils/flatbuffers/flatbuffers.fbs";
+include "annotator/entity-data.fbs";
file_identifier "TC2 ";
diff --git a/native/annotator/types.h b/native/annotator/types.h
index 732e18d..3063838 100644
--- a/native/annotator/types.h
+++ b/native/annotator/types.h
@@ -431,7 +431,8 @@
std::string serialized_knowledge_result;
ContactPointer contact_pointer;
std::string contact_name, contact_given_name, contact_family_name,
- contact_nickname, contact_email_address, contact_phone_number, contact_id;
+ contact_nickname, contact_email_address, contact_phone_number,
+ contact_account_type, contact_account_name, contact_id;
std::string app_name, app_package_name;
int64 numeric_value;
double numeric_double_value;
@@ -684,6 +685,8 @@
struct InputFragment {
std::string text;
+ float bounding_box_top;
+ float bounding_box_height;
// If present will override the AnnotationOptions reference time and timezone
// when annotating this specific string fragment.
diff --git a/native/lang_id/common/file/mmap.cc b/native/lang_id/common/file/mmap.cc
index 444d0d0..19afcc4 100644
--- a/native/lang_id/common/file/mmap.cc
+++ b/native/lang_id/common/file/mmap.cc
@@ -160,6 +160,7 @@
SAFTM_LOG(ERROR) << "Error closing file descriptor: " << last_error;
}
}
+
private:
const int fd_;
@@ -199,12 +200,19 @@
}
MmapHandle MmapFile(int fd, size_t offset_in_bytes, size_t size_in_bytes) {
+ // Make sure the offset is a multiple of the page size, as returned by
+ // sysconf(_SC_PAGE_SIZE); this is required by the man-page for mmap.
+ static const size_t kPageSize = sysconf(_SC_PAGE_SIZE);
+ const size_t aligned_offset = (offset_in_bytes / kPageSize) * kPageSize;
+ const size_t alignment_shift = offset_in_bytes - aligned_offset;
+ const size_t aligned_length = size_in_bytes + alignment_shift;
+
void *mmap_addr = mmap(
// Let system pick address for mmapp-ed data.
nullptr,
- size_in_bytes,
+ aligned_length,
// One can read / write the mapped data (but see MAP_PRIVATE below).
// Normally, we expect only to read it, but in the future, we may want to
@@ -218,14 +226,15 @@
// Descriptor of file to mmap.
fd,
- offset_in_bytes);
+ aligned_offset);
if (mmap_addr == MAP_FAILED) {
const std::string last_error = GetLastSystemError();
SAFTM_LOG(ERROR) << "Error while mmapping: " << last_error;
return GetErrorMmapHandle();
}
- return MmapHandle(mmap_addr, size_in_bytes);
+ return MmapHandle(static_cast<char *>(mmap_addr) + alignment_shift,
+ size_in_bytes);
}
bool Unmap(MmapHandle mmap_handle) {
diff --git a/native/lang_id/lang-id_jni.cc b/native/lang_id/lang-id_jni.cc
index 84347cd..e86f198 100644
--- a/native/lang_id/lang-id_jni.cc
+++ b/native/lang_id/lang-id_jni.cc
@@ -176,3 +176,13 @@
LangId* model = reinterpret_cast<LangId*>(ptr);
return model->GetFloatProperty("min_text_size_in_bytes", 0);
}
+
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size) {
+ std::unique_ptr<LangId> lang_id =
+ GetLangIdFromFlatbufferFileDescriptor(fd, offset, size);
+ if (!lang_id->is_valid()) {
+ return -1;
+ }
+ return lang_id->GetModelVersion();
+}
diff --git a/native/lang_id/lang-id_jni.h b/native/lang_id/lang-id_jni.h
index e3ba610..e917197 100644
--- a/native/lang_id/lang-id_jni.h
+++ b/native/lang_id/lang-id_jni.h
@@ -20,7 +20,9 @@
#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_JNI_H_
#include <jni.h>
+
#include <string>
+
#include "utils/java/jni-base.h"
#ifndef TC3_LANG_ID_CLASS_NAME
@@ -63,6 +65,9 @@
TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetMinTextSizeInBytes)
(JNIEnv* env, jobject thizz, jlong ptr);
+TC3_JNI_METHOD(jint, TC3_LANG_ID_CLASS_NAME, nativeGetVersionWithOffset)
+(JNIEnv* env, jobject clazz, jint fd, jlong offset, jlong size);
+
#ifdef __cplusplus
}
#endif
diff --git a/native/lang_id/script/tiny-script-detector.h b/native/lang_id/script/tiny-script-detector.h
index a55da04..d08270c 100644
--- a/native/lang_id/script/tiny-script-detector.h
+++ b/native/lang_id/script/tiny-script-detector.h
@@ -74,12 +74,12 @@
// CPU, so it's better to use than int32.
static const unsigned int kGreekStart = 0x370;
- // Commented out (unsued in the code): kGreekEnd = 0x3FF;
+ // Commented out (unused in the code): kGreekEnd = 0x3FF;
static const unsigned int kCyrillicStart = 0x400;
static const unsigned int kCyrillicEnd = 0x4FF;
static const unsigned int kHebrewStart = 0x590;
- // Commented out (unsued in the code): kHebrewEnd = 0x5FF;
+ // Commented out (unused in the code): kHebrewEnd = 0x5FF;
static const unsigned int kArabicStart = 0x600;
static const unsigned int kArabicEnd = 0x6FF;
const unsigned int codepoint = ((p[0] & 0x1F) << 6) | (p[1] & 0x3F);
@@ -117,7 +117,7 @@
static const unsigned int kHiraganaStart = 0x3041;
static const unsigned int kHiraganaEnd = 0x309F;
- // Commented out (unsued in the code): kKatakanaStart = 0x30A0;
+ // Commented out (unused in the code): kKatakanaStart = 0x30A0;
static const unsigned int kKatakanaEnd = 0x30FF;
const unsigned int codepoint =
((p[0] & 0x0F) << 12) | ((p[1] & 0x3F) << 6) | (p[2] & 0x3F);
diff --git a/native/utils/grammar/rules.fbs b/native/utils/grammar/rules.fbs
index 3225892..021fe0f 100755
--- a/native/utils/grammar/rules.fbs
+++ b/native/utils/grammar/rules.fbs
@@ -14,9 +14,9 @@
// limitations under the License.
//
-include "utils/i18n/language-tag.fbs";
-include "utils/zlib/buffer.fbs";
include "utils/grammar/semantics/expression.fbs";
+include "utils/zlib/buffer.fbs";
+include "utils/i18n/language-tag.fbs";
// The terminal rules map as sorted strings table.
// The sorted terminal strings table is represented as offsets into the
diff --git a/native/utils/i18n/locale-list.cc b/native/utils/i18n/locale-list.cc
new file mode 100644
index 0000000..a0be5ac
--- /dev/null
+++ b/native/utils/i18n/locale-list.cc
@@ -0,0 +1,44 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/i18n/locale-list.h"
+
+#include <string>
+
+namespace libtextclassifier3 {
+
+LocaleList LocaleList::ParseFrom(const std::string& locale_tags) {
+ std::vector<StringPiece> split_locales = strings::Split(locale_tags, ',');
+ std::string reference_locale;
+ if (!split_locales.empty()) {
+ // Assigns the first parsed locale to reference_locale.
+ reference_locale = split_locales[0].ToString();
+ } else {
+ reference_locale = "";
+ }
+ std::vector<Locale> locales;
+ for (const StringPiece& locale_str : split_locales) {
+ const Locale locale = Locale::FromBCP47(locale_str.ToString());
+ if (!locale.IsValid()) {
+ TC3_LOG(WARNING) << "Failed to parse the detected_text_language_tag: "
+ << locale_str.ToString();
+ }
+ locales.push_back(locale);
+ }
+ return LocaleList(locales, split_locales, reference_locale);
+}
+
+} // namespace libtextclassifier3
diff --git a/native/utils/i18n/locale-list.h b/native/utils/i18n/locale-list.h
new file mode 100644
index 0000000..cf2e06d
--- /dev/null
+++ b/native/utils/i18n/locale-list.h
@@ -0,0 +1,55 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_
+#define LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_
+
+#include <string>
+
+#include "utils/i18n/locale.h"
+#include "utils/strings/split.h"
+
+namespace libtextclassifier3 {
+
+// Parses and hold data about locales (combined by delimiter ',').
+class LocaleList {
+ public:
+ // Constructs the
+ // - Collection of locale tag from local_tags
+ // - Collection of Locale objects from a valid BCP47 tag. (If the tag is
+ // invalid, an object is created but return false for IsInvalid() call.
+ // - Assigns the first parsed locale to reference_locale.
+ static LocaleList ParseFrom(const std::string& locale_tags);
+
+ std::vector<Locale> GetLocales() const { return locales_; }
+ std::vector<StringPiece> GetLocaleTags() const { return split_locales_; }
+ std::string GetReferenceLocale() const { return reference_locale_; }
+
+ private:
+ LocaleList(const std::vector<Locale>& locales,
+ const std::vector<StringPiece>& split_locales,
+ const StringPiece& reference_locale)
+ : locales_(locales),
+ split_locales_(split_locales),
+ reference_locale_(reference_locale.ToString()) {}
+
+ const std::vector<Locale> locales_;
+ const std::vector<StringPiece> split_locales_;
+ const std::string reference_locale_;
+};
+} // namespace libtextclassifier3
+
+#endif // LIBTEXTCLASSIFIER_UTILS_I18N_LOCALE_LIST_H_
diff --git a/native/utils/i18n/locale-list_test.cc b/native/utils/i18n/locale-list_test.cc
new file mode 100644
index 0000000..d7cfd17
--- /dev/null
+++ b/native/utils/i18n/locale-list_test.cc
@@ -0,0 +1,51 @@
+/*
+ * Copyright (C) 2018 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "utils/i18n/locale-list.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using ::testing::SizeIs;
+
+namespace libtextclassifier3 {
+namespace {
+
+TEST(LocaleTest, ParsedLocalesSanityCheck) {
+ LocaleList locale_list = LocaleList::ParseFrom("en-US,zh-CN,ar,en");
+ EXPECT_THAT(locale_list.GetLocales(), SizeIs(4));
+ EXPECT_THAT(locale_list.GetLocaleTags(), SizeIs(4));
+ EXPECT_EQ(locale_list.GetReferenceLocale(), "en-US");
+}
+
+TEST(LocaleTest, ParsedLocalesEmpty) {
+ LocaleList locale_list = LocaleList::ParseFrom("");
+ EXPECT_THAT(locale_list.GetLocales(), SizeIs(0));
+ EXPECT_THAT(locale_list.GetLocaleTags(), SizeIs(0));
+ EXPECT_EQ(locale_list.GetReferenceLocale(), "");
+}
+
+TEST(LocaleTest, ParsedLocalesIvalid) {
+ LocaleList locale_list = LocaleList::ParseFrom("en,invalid");
+ EXPECT_THAT(locale_list.GetLocales(), SizeIs(2));
+ EXPECT_THAT(locale_list.GetLocaleTags(), SizeIs(2));
+ EXPECT_EQ(locale_list.GetReferenceLocale(), "en");
+ EXPECT_TRUE(locale_list.GetLocales()[0].IsValid());
+ EXPECT_FALSE(locale_list.GetLocales()[1].IsValid());
+}
+
+} // namespace
+} // namespace libtextclassifier3
diff --git a/native/utils/tokenizer.cc b/native/utils/tokenizer.cc
index da66ff6..20f72c4 100644
--- a/native/utils/tokenizer.cc
+++ b/native/utils/tokenizer.cc
@@ -50,6 +50,10 @@
SortCodepointRanges(internal_tokenizer_codepoint_ranges,
&internal_tokenizer_codepoint_ranges_);
+ if (type_ == TokenizationType_MIXED && split_on_script_change) {
+ TC3_LOG(ERROR) << "The option `split_on_script_change` is unavailable for "
+ "the selected tokenizer type (mixed).";
+ }
}
const TokenizationCodepointRangeT* Tokenizer::FindTokenizationRange(
@@ -234,15 +238,20 @@
if (!break_iterator) {
return false;
}
+ const int context_unicode_size = context_unicode.size_codepoints();
int last_unicode_index = 0;
int unicode_index = 0;
auto token_begin_it = context_unicode.begin();
while ((unicode_index = break_iterator->Next()) !=
UniLib::BreakIterator::kDone) {
const int token_length = unicode_index - last_unicode_index;
+ if (token_length + last_unicode_index > context_unicode_size) {
+ return false;
+ }
auto token_end_it = token_begin_it;
std::advance(token_end_it, token_length);
+ TC3_CHECK(token_end_it <= context_unicode.end());
// Determine if the whole token is whitespace.
bool is_whitespace = true;
diff --git a/native/utils/utf8/unilib-javaicu.cc b/native/utils/utf8/unilib-javaicu.cc
index e56f979..befe639 100644
--- a/native/utils/utf8/unilib-javaicu.cc
+++ b/native/utils/utf8/unilib-javaicu.cc
@@ -25,6 +25,7 @@
#include "utils/base/logging.h"
#include "utils/base/statusor.h"
#include "utils/java/jni-base.h"
+#include "utils/java/jni-helper.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
@@ -80,6 +81,20 @@
// Implementations that call out to JVM. Behold the beauty.
// -----------------------------------------------------------------------------
+StatusOr<int32> UniLibBase::Length(const UnicodeText& text) const {
+ TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jstring> text_java,
+ jni_cache_->ConvertToJavaString(text));
+
+ JNIEnv* jenv = jni_cache_->GetEnv();
+ TC3_ASSIGN_OR_RETURN(int utf16_length,
+ JniHelper::CallIntMethod(jenv, text_java.get(),
+ jni_cache_->string_length));
+
+ return JniHelper::CallIntMethod(jenv, text_java.get(),
+ jni_cache_->string_code_point_count, 0,
+ utf16_length);
+}
+
bool UniLibBase::ParseInt32(const UnicodeText& text, int32* result) const {
return ParseInt(text, result);
}
diff --git a/native/utils/utf8/unilib-javaicu.h b/native/utils/utf8/unilib-javaicu.h
index 3d645b5..8b04789 100644
--- a/native/utils/utf8/unilib-javaicu.h
+++ b/native/utils/utf8/unilib-javaicu.h
@@ -57,6 +57,8 @@
char32 ToUpper(char32 codepoint) const;
char32 GetPairedBracket(char32 codepoint) const;
+ StatusOr<int32> Length(const UnicodeText& text) const;
+
// Forward declaration for friend.
class RegexPattern;
diff --git a/native/utils/utf8/unilib.h b/native/utils/utf8/unilib.h
index c9a3461..ffda7d9 100644
--- a/native/utils/utf8/unilib.h
+++ b/native/utils/utf8/unilib.h
@@ -152,6 +152,31 @@
bool IsLetter(char32 codepoint) const {
return libtextclassifier3::IsLetter(codepoint);
}
+
+ bool IsValidUtf8(const UnicodeText& text) const {
+ // Basic check of structural validity of UTF8.
+ if (!text.is_valid()) {
+ return false;
+ }
+ // In addition to that, we declare that a valid UTF8 is when the number of
+ // codepoints in the string as measured by ICU is the same as the number of
+ // codepoints as measured by UnicodeText. Because if we don't do this check,
+ // the indices might differ, and cause trouble, because the assumption
+ // throughout the code is that ICU indices and UnicodeText indices are the
+ // same.
+ // NOTE: This is not perfect, as this doesn't check the alignment of the
+ // codepoints, but for the practical purposes should be enough.
+ const StatusOr<int32> icu_length = Length(text);
+ if (!icu_length.ok()) {
+ return false;
+ }
+
+ if (icu_length.ValueOrDie() != text.size_codepoints()) {
+ return false;
+ }
+
+ return true;
+ }
};
} // namespace libtextclassifier3
diff --git a/native/utils/utf8/unilib_test-include.cc b/native/utils/utf8/unilib_test-include.cc
index 675ea26..ed0f184 100644
--- a/native/utils/utf8/unilib_test-include.cc
+++ b/native/utils/utf8/unilib_test-include.cc
@@ -527,5 +527,22 @@
UTF8ToUnicodeText("Information", /*do_copy=*/false), &result));
}
+TEST_F(UniLibTest, Length) {
+ EXPECT_EQ(unilib_->Length(UTF8ToUnicodeText("hello", /*do_copy=*/false))
+ .ValueOrDie(),
+ 5);
+ EXPECT_EQ(unilib_->Length(UTF8ToUnicodeText("ěščřž", /*do_copy=*/false))
+ .ValueOrDie(),
+ 5);
+ // Test Invalid UTF8.
+ // This testing condition needs to be != 1, as Apple character counting seems
+ // to return 0 when the input is invalid UTF8, while ICU will treat the
+ // invalid codepoint as 3 separate bytes.
+ EXPECT_NE(
+ unilib_->Length(UTF8ToUnicodeText("\xed\xa0\x80", /*do_copy=*/false))
+ .ValueOrDie(),
+ 1);
+}
+
} // namespace test_internal
} // namespace libtextclassifier3