| /* |
| * Copyright (C) 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| // JNI wrapper for the TextClassifier. |
| |
| #include "textclassifier_jni.h" |
| |
| #include <jni.h> |
| #include <type_traits> |
| #include <vector> |
| |
| #include "text-classifier.h" |
| #include "util/base/integral_types.h" |
| #include "util/java/scoped_local_ref.h" |
| #include "util/java/string_utils.h" |
| #include "util/memory/mmap.h" |
| #include "util/utf8/unilib.h" |
| |
| using libtextclassifier2::AnnotatedSpan; |
| using libtextclassifier2::AnnotationOptions; |
| using libtextclassifier2::ClassificationOptions; |
| using libtextclassifier2::ClassificationResult; |
| using libtextclassifier2::CodepointSpan; |
| using libtextclassifier2::JStringToUtf8String; |
| using libtextclassifier2::Model; |
| using libtextclassifier2::ScopedLocalRef; |
| using libtextclassifier2::SelectionOptions; |
| using libtextclassifier2::TextClassifier; |
| #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU |
| using libtextclassifier2::UniLib; |
| #endif |
| |
| namespace libtextclassifier2 { |
| |
| using libtextclassifier2::CodepointSpan; |
| |
| namespace { |
| |
| std::string ToStlString(JNIEnv* env, const jstring& str) { |
| std::string result; |
| JStringToUtf8String(env, str, &result); |
| return result; |
| } |
| |
| jobjectArray ClassificationResultsToJObjectArray( |
| JNIEnv* env, |
| const std::vector<ClassificationResult>& classification_result) { |
| const ScopedLocalRef<jclass> result_class( |
| env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult"), |
| env); |
| if (!result_class) { |
| TC_LOG(ERROR) << "Couldn't find ClassificationResult class."; |
| return nullptr; |
| } |
| const ScopedLocalRef<jclass> datetime_parse_class( |
| env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$DatetimeResult"), env); |
| if (!datetime_parse_class) { |
| TC_LOG(ERROR) << "Couldn't find DatetimeResult class."; |
| return nullptr; |
| } |
| |
| const jmethodID result_class_constructor = |
| env->GetMethodID(result_class.get(), "<init>", |
| "(Ljava/lang/String;FL" TC_PACKAGE_PATH TC_CLASS_NAME_STR |
| "$DatetimeResult;)V"); |
| const jmethodID datetime_parse_class_constructor = |
| env->GetMethodID(datetime_parse_class.get(), "<init>", "(JI)V"); |
| |
| const jobjectArray results = env->NewObjectArray(classification_result.size(), |
| result_class.get(), nullptr); |
| for (int i = 0; i < classification_result.size(); i++) { |
| jstring row_string = |
| env->NewStringUTF(classification_result[i].collection.c_str()); |
| jobject row_datetime_parse = nullptr; |
| if (classification_result[i].datetime_parse_result.IsSet()) { |
| row_datetime_parse = env->NewObject( |
| datetime_parse_class.get(), datetime_parse_class_constructor, |
| classification_result[i].datetime_parse_result.time_ms_utc, |
| classification_result[i].datetime_parse_result.granularity); |
| } |
| jobject result = |
| env->NewObject(result_class.get(), result_class_constructor, row_string, |
| static_cast<jfloat>(classification_result[i].score), |
| row_datetime_parse); |
| env->SetObjectArrayElement(results, i, result); |
| env->DeleteLocalRef(result); |
| } |
| return results; |
| } |
| |
| template <typename T, typename F> |
| std::pair<bool, T> CallJniMethod0(JNIEnv* env, jobject object, |
| jclass class_object, F function, |
| const std::string& method_name, |
| const std::string& return_java_type) { |
| const jmethodID method = env->GetMethodID(class_object, method_name.c_str(), |
| ("()" + return_java_type).c_str()); |
| if (!method) { |
| return std::make_pair(false, T()); |
| } |
| return std::make_pair(true, (env->*function)(object, method)); |
| } |
| |
| SelectionOptions FromJavaSelectionOptions(JNIEnv* env, jobject joptions) { |
| if (!joptions) { |
| return {}; |
| } |
| |
| const ScopedLocalRef<jclass> options_class( |
| env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$SelectionOptions"), |
| env); |
| const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( |
| env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, |
| "getLocales", "Ljava/lang/String;"); |
| if (!status_or_locales.first) { |
| return {}; |
| } |
| |
| SelectionOptions options; |
| options.locales = |
| ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); |
| |
| return options; |
| } |
| |
| template <typename T> |
| T FromJavaOptionsInternal(JNIEnv* env, jobject joptions, |
| const std::string& class_name) { |
| if (!joptions) { |
| return {}; |
| } |
| |
| const ScopedLocalRef<jclass> options_class(env->FindClass(class_name.c_str()), |
| env); |
| if (!options_class) { |
| return {}; |
| } |
| |
| const std::pair<bool, jobject> status_or_locales = CallJniMethod0<jobject>( |
| env, joptions, options_class.get(), &JNIEnv::CallObjectMethod, |
| "getLocale", "Ljava/lang/String;"); |
| const std::pair<bool, jobject> status_or_reference_timezone = |
| CallJniMethod0<jobject>(env, joptions, options_class.get(), |
| &JNIEnv::CallObjectMethod, "getReferenceTimezone", |
| "Ljava/lang/String;"); |
| const std::pair<bool, int64> status_or_reference_time_ms_utc = |
| CallJniMethod0<int64>(env, joptions, options_class.get(), |
| &JNIEnv::CallLongMethod, "getReferenceTimeMsUtc", |
| "J"); |
| |
| if (!status_or_locales.first || !status_or_reference_timezone.first || |
| !status_or_reference_time_ms_utc.first) { |
| return {}; |
| } |
| |
| T options; |
| options.locales = |
| ToStlString(env, reinterpret_cast<jstring>(status_or_locales.second)); |
| options.reference_timezone = ToStlString( |
| env, reinterpret_cast<jstring>(status_or_reference_timezone.second)); |
| options.reference_time_ms_utc = status_or_reference_time_ms_utc.second; |
| return options; |
| } |
| |
| ClassificationOptions FromJavaClassificationOptions(JNIEnv* env, |
| jobject joptions) { |
| return FromJavaOptionsInternal<ClassificationOptions>( |
| env, joptions, |
| TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationOptions"); |
| } |
| |
| AnnotationOptions FromJavaAnnotationOptions(JNIEnv* env, jobject joptions) { |
| return FromJavaOptionsInternal<AnnotationOptions>( |
| env, joptions, TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotationOptions"); |
| } |
| |
| CodepointSpan ConvertIndicesBMPUTF8(const std::string& utf8_str, |
| CodepointSpan orig_indices, |
| bool from_utf8) { |
| const libtextclassifier2::UnicodeText unicode_str = |
| libtextclassifier2::UTF8ToUnicodeText(utf8_str, /*do_copy=*/false); |
| |
| int unicode_index = 0; |
| int bmp_index = 0; |
| |
| const int* source_index; |
| const int* target_index; |
| if (from_utf8) { |
| source_index = &unicode_index; |
| target_index = &bmp_index; |
| } else { |
| source_index = &bmp_index; |
| target_index = &unicode_index; |
| } |
| |
| CodepointSpan result{-1, -1}; |
| std::function<void()> assign_indices_fn = [&result, &orig_indices, |
| &source_index, &target_index]() { |
| if (orig_indices.first == *source_index) { |
| result.first = *target_index; |
| } |
| |
| if (orig_indices.second == *source_index) { |
| result.second = *target_index; |
| } |
| }; |
| |
| for (auto it = unicode_str.begin(); it != unicode_str.end(); |
| ++it, ++unicode_index, ++bmp_index) { |
| assign_indices_fn(); |
| |
| // There is 1 extra character in the input for each UTF8 character > 0xFFFF. |
| if (*it > 0xFFFF) { |
| ++bmp_index; |
| } |
| } |
| assign_indices_fn(); |
| |
| return result; |
| } |
| |
| } // namespace |
| |
| CodepointSpan ConvertIndicesBMPToUTF8(const std::string& utf8_str, |
| CodepointSpan bmp_indices) { |
| return ConvertIndicesBMPUTF8(utf8_str, bmp_indices, /*from_utf8=*/false); |
| } |
| |
| CodepointSpan ConvertIndicesUTF8ToBMP(const std::string& utf8_str, |
| CodepointSpan utf8_indices) { |
| return ConvertIndicesBMPUTF8(utf8_str, utf8_indices, /*from_utf8=*/true); |
| } |
| |
| jint GetFdFromAssetFileDescriptor(JNIEnv* env, jobject afd) { |
| // Get system-level file descriptor from AssetFileDescriptor. |
| ScopedLocalRef<jclass> afd_class( |
| env->FindClass("android/content/res/AssetFileDescriptor"), env); |
| if (afd_class == nullptr) { |
| TC_LOG(ERROR) << "Couldn't find AssetFileDescriptor."; |
| return reinterpret_cast<jlong>(nullptr); |
| } |
| jmethodID afd_class_getFileDescriptor = env->GetMethodID( |
| afd_class.get(), "getFileDescriptor", "()Ljava/io/FileDescriptor;"); |
| if (afd_class_getFileDescriptor == nullptr) { |
| TC_LOG(ERROR) << "Couldn't find getFileDescriptor."; |
| return reinterpret_cast<jlong>(nullptr); |
| } |
| |
| ScopedLocalRef<jclass> fd_class(env->FindClass("java/io/FileDescriptor"), |
| env); |
| if (fd_class == nullptr) { |
| TC_LOG(ERROR) << "Couldn't find FileDescriptor."; |
| return reinterpret_cast<jlong>(nullptr); |
| } |
| jfieldID fd_class_descriptor = |
| env->GetFieldID(fd_class.get(), "descriptor", "I"); |
| if (fd_class_descriptor == nullptr) { |
| TC_LOG(ERROR) << "Couldn't find descriptor."; |
| return reinterpret_cast<jlong>(nullptr); |
| } |
| |
| jobject bundle_jfd = env->CallObjectMethod(afd, afd_class_getFileDescriptor); |
| return env->GetIntField(bundle_jfd, fd_class_descriptor); |
| } |
| |
| jstring GetLocalesFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { |
| if (!mmap->handle().ok()) { |
| return env->NewStringUTF(""); |
| } |
| const Model* model = libtextclassifier2::ViewModel( |
| mmap->handle().start(), mmap->handle().num_bytes()); |
| if (!model || !model->locales()) { |
| return env->NewStringUTF(""); |
| } |
| return env->NewStringUTF(model->locales()->c_str()); |
| } |
| |
| jint GetVersionFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { |
| if (!mmap->handle().ok()) { |
| return 0; |
| } |
| const Model* model = libtextclassifier2::ViewModel( |
| mmap->handle().start(), mmap->handle().num_bytes()); |
| if (!model) { |
| return 0; |
| } |
| return model->version(); |
| } |
| |
| jstring GetNameFromMmap(JNIEnv* env, libtextclassifier2::ScopedMmap* mmap) { |
| if (!mmap->handle().ok()) { |
| return env->NewStringUTF(""); |
| } |
| const Model* model = libtextclassifier2::ViewModel( |
| mmap->handle().start(), mmap->handle().num_bytes()); |
| if (!model || !model->name()) { |
| return env->NewStringUTF(""); |
| } |
| return env->NewStringUTF(model->name()->c_str()); |
| } |
| |
| } // namespace libtextclassifier2 |
| |
| using libtextclassifier2::ClassificationResultsToJObjectArray; |
| using libtextclassifier2::ConvertIndicesBMPToUTF8; |
| using libtextclassifier2::ConvertIndicesUTF8ToBMP; |
| using libtextclassifier2::FromJavaAnnotationOptions; |
| using libtextclassifier2::FromJavaClassificationOptions; |
| using libtextclassifier2::FromJavaSelectionOptions; |
| using libtextclassifier2::ToStlString; |
| |
| JNI_METHOD(jlong, TC_CLASS_NAME, nativeNew) |
| (JNIEnv* env, jobject thiz, jint fd) { |
| #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU |
| return reinterpret_cast<jlong>( |
| TextClassifier::FromFileDescriptor(fd).release(), new UniLib(env)); |
| #else |
| return reinterpret_cast<jlong>( |
| TextClassifier::FromFileDescriptor(fd).release()); |
| #endif |
| } |
| |
| JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromPath) |
| (JNIEnv* env, jobject thiz, jstring path) { |
| const std::string path_str = ToStlString(env, path); |
| #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU |
| return reinterpret_cast<jlong>( |
| TextClassifier::FromPath(path_str, new UniLib(env)).release()); |
| #else |
| return reinterpret_cast<jlong>(TextClassifier::FromPath(path_str).release()); |
| #endif |
| } |
| |
| JNI_METHOD(jlong, TC_CLASS_NAME, nativeNewFromAssetFileDescriptor) |
| (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { |
| const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); |
| #ifdef LIBTEXTCLASSIFIER_UNILIB_JAVAICU |
| return reinterpret_cast<jlong>( |
| TextClassifier::FromFileDescriptor(fd, offset, size, new UniLib(env)) |
| .release()); |
| #else |
| return reinterpret_cast<jlong>( |
| TextClassifier::FromFileDescriptor(fd, offset, size).release()); |
| #endif |
| } |
| |
| JNI_METHOD(jintArray, TC_CLASS_NAME, nativeSuggestSelection) |
| (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, |
| jint selection_end, jobject options) { |
| if (!ptr) { |
| return nullptr; |
| } |
| |
| TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); |
| |
| const std::string context_utf8 = ToStlString(env, context); |
| CodepointSpan input_indices = |
| ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); |
| CodepointSpan selection = model->SuggestSelection( |
| context_utf8, input_indices, FromJavaSelectionOptions(env, options)); |
| selection = ConvertIndicesUTF8ToBMP(context_utf8, selection); |
| |
| jintArray result = env->NewIntArray(2); |
| env->SetIntArrayRegion(result, 0, 1, &(std::get<0>(selection))); |
| env->SetIntArrayRegion(result, 1, 1, &(std::get<1>(selection))); |
| return result; |
| } |
| |
| JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeClassifyText) |
| (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jint selection_begin, |
| jint selection_end, jobject options) { |
| if (!ptr) { |
| return nullptr; |
| } |
| TextClassifier* ff_model = reinterpret_cast<TextClassifier*>(ptr); |
| |
| const std::string context_utf8 = ToStlString(env, context); |
| const CodepointSpan input_indices = |
| ConvertIndicesBMPToUTF8(context_utf8, {selection_begin, selection_end}); |
| const std::vector<ClassificationResult> classification_result = |
| ff_model->ClassifyText(context_utf8, input_indices, |
| FromJavaClassificationOptions(env, options)); |
| |
| return ClassificationResultsToJObjectArray(env, classification_result); |
| } |
| |
| JNI_METHOD(jobjectArray, TC_CLASS_NAME, nativeAnnotate) |
| (JNIEnv* env, jobject thiz, jlong ptr, jstring context, jobject options) { |
| if (!ptr) { |
| return nullptr; |
| } |
| TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); |
| std::string context_utf8 = ToStlString(env, context); |
| std::vector<AnnotatedSpan> annotations = |
| model->Annotate(context_utf8, FromJavaAnnotationOptions(env, options)); |
| |
| jclass result_class = |
| env->FindClass(TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"); |
| if (!result_class) { |
| TC_LOG(ERROR) << "Couldn't find result class: " |
| << TC_PACKAGE_PATH TC_CLASS_NAME_STR "$AnnotatedSpan"; |
| return nullptr; |
| } |
| |
| jmethodID result_class_constructor = env->GetMethodID( |
| result_class, "<init>", |
| "(II[L" TC_PACKAGE_PATH TC_CLASS_NAME_STR "$ClassificationResult;)V"); |
| |
| jobjectArray results = |
| env->NewObjectArray(annotations.size(), result_class, nullptr); |
| |
| for (int i = 0; i < annotations.size(); ++i) { |
| CodepointSpan span_bmp = |
| ConvertIndicesUTF8ToBMP(context_utf8, annotations[i].span); |
| jobject result = env->NewObject( |
| result_class, result_class_constructor, |
| static_cast<jint>(span_bmp.first), static_cast<jint>(span_bmp.second), |
| ClassificationResultsToJObjectArray(env, |
| |
| annotations[i].classification)); |
| env->SetObjectArrayElement(results, i, result); |
| env->DeleteLocalRef(result); |
| } |
| env->DeleteLocalRef(result_class); |
| return results; |
| } |
| |
| JNI_METHOD(void, TC_CLASS_NAME, nativeClose) |
| (JNIEnv* env, jobject thiz, jlong ptr) { |
| TextClassifier* model = reinterpret_cast<TextClassifier*>(ptr); |
| delete model; |
| } |
| |
| JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLanguage) |
| (JNIEnv* env, jobject clazz, jint fd) { |
| TC_LOG(WARNING) << "Using deprecated getLanguage()."; |
| return JNI_METHOD_NAME(TC_CLASS_NAME, nativeGetLocales)(env, clazz, fd); |
| } |
| |
| JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocales) |
| (JNIEnv* env, jobject clazz, jint fd) { |
| const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( |
| new libtextclassifier2::ScopedMmap(fd)); |
| return GetLocalesFromMmap(env, mmap.get()); |
| } |
| |
| JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetLocalesFromAssetFileDescriptor) |
| (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { |
| const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); |
| const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( |
| new libtextclassifier2::ScopedMmap(fd, offset, size)); |
| return GetLocalesFromMmap(env, mmap.get()); |
| } |
| |
| JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersion) |
| (JNIEnv* env, jobject clazz, jint fd) { |
| const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( |
| new libtextclassifier2::ScopedMmap(fd)); |
| return GetVersionFromMmap(env, mmap.get()); |
| } |
| |
| JNI_METHOD(jint, TC_CLASS_NAME, nativeGetVersionFromAssetFileDescriptor) |
| (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { |
| const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); |
| const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( |
| new libtextclassifier2::ScopedMmap(fd, offset, size)); |
| return GetVersionFromMmap(env, mmap.get()); |
| } |
| |
| JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetName) |
| (JNIEnv* env, jobject clazz, jint fd) { |
| const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( |
| new libtextclassifier2::ScopedMmap(fd)); |
| return GetNameFromMmap(env, mmap.get()); |
| } |
| |
| JNI_METHOD(jstring, TC_CLASS_NAME, nativeGetNameFromAssetFileDescriptor) |
| (JNIEnv* env, jobject thiz, jobject afd, jlong offset, jlong size) { |
| const jint fd = libtextclassifier2::GetFdFromAssetFileDescriptor(env, afd); |
| const std::unique_ptr<libtextclassifier2::ScopedMmap> mmap( |
| new libtextclassifier2::ScopedMmap(fd, offset, size)); |
| return GetNameFromMmap(env, mmap.get()); |
| } |