// blob: 12ded7d1cf46b3a15f1a7d853e85891a4d5f9769 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/intents/intent-generator.h"
#include <map>
#include "utils/base/logging.h"
#include "utils/java/string_utils.h"
#include "utils/lua-utils.h"
#include "utils/strings/stringpiece.h"
#include "utils/variant.h"
#ifdef __cplusplus
extern "C" {
#endif
#include "lauxlib.h"
#include "lua.h"
#ifdef __cplusplus
}
#endif
namespace libtextclassifier3 {
namespace {
// Names of the fields a generator snippet can look up from Lua.
// NOTE(review): the identifiers say "Usec" while the key strings say "ms";
// the strings are what snippets see at runtime, so they are kept as-is.
constexpr const char* kEntityTextKey = "text";
constexpr const char* kTimeUsecKey = "parsed_time_ms_utc";
constexpr const char* kReferenceTimeUsecKey = "reference_time_ms_utc";
// An Android specific Lua environment with JNI backed callbacks.
//
// Generator snippets run inside this environment and reach the classification
// data and a few Android helpers through callbacks that are dispatched by
// HandleCallback().
class JniLuaEnvironment : public LuaEnvironment {
public:
// Creates the environment and registers the callbacks.
// `jni_cache` may be null, in which case no JNIEnv is available (jenv_ is
// null). `extra` is stored by reference, so it must outlive this object.
// NOTE(review): `entity_text` is stored as a StringPiece — presumably a
// non-owning view, so the underlying text must outlive this object as well;
// confirm against StringPiece.
JniLuaEnvironment(const JniCache* jni_cache, const jobject context,
StringPiece entity_text, int64 event_time_ms_usec,
int64 reference_time_ms_utc,
const std::map<std::string, Variant>& extra);
// Runs an intent generator snippet.
// Returns the action templates read from the table the snippet returns, or an
// empty vector if loading, running or reading the result fails.
std::vector<RemoteActionTemplate> RunIntentGenerator(
const std::string& generator_snippet);
protected:
// Dispatches `callback_id` (one of CallbackId) to the matching handler below.
int HandleCallback(int callback_id) override;
private:
// Callback handlers.
// Each one reads its argument from the Lua stack and pushes its result there.
int HandleExternalCallback();
int HandleExtrasLookup();
int HandleAndroidCallback();
int HandleUserRestrictionsCallback();
int HandleUrlEncode();
int HandleUrlSchema();
// Reads and creates a RemoteAction result from Lua.
RemoteActionTemplate ReadRemoteActionTemplateResult();
// Reads the extras from the Lua result.
void ReadExtras(std::map<std::string, Variant>* extra);
// Reads the intent categories array from a Lua result.
void ReadCategories(std::vector<std::string>* category);
// Retrieves user manager if not previously done.
// Returns true if a user manager reference is available afterwards.
bool RetrieveUserManager();
// Builtins.
// Identifiers of the callbacks that HandleCallback() dispatches.
enum CallbackId {
CALLBACK_ID_EXTERNAL = 0,
CALLBACK_ID_EXTRAS = 1,
CALLBACK_ID_ANDROID = 2,
CALLBACK_ID_USER_PERMISSIONS = 3,
CALLBACK_ID_URL_ENCODE = 4,
CALLBACK_ID_URL_SCHEMA = 5,
};
// JNI environment taken from `jni_cache_`; null when no cache was provided.
JNIEnv* jenv_;
const JniCache* jni_cache_;
// The context object the Android callbacks are invoked on; may be null.
jobject context_;
// Values served by the `extras` lookup callback.
StringPiece entity_text_;
int64 event_time_ms_usec_;
// Value served by the `external` lookup callback.
int64 reference_time_ms_utc_;
// Additional classification values, looked up by key. Held by reference.
const std::map<std::string, Variant>& extra_;
// Global reference to the user manager, filled in by RetrieveUserManager().
ScopedGlobalRef<jobject> usermanager_;
// Whether we previously attempted to retrieve the UserManager before.
bool usermanager_retrieved_;
};
// Builds the Lua environment and installs the `external` global that
// generator snippets use to reach classification data and Android helpers.
//
// A null `jni_cache` is tolerated here: the JNI environment and the JVM used
// for the user manager reference are then left null.
JniLuaEnvironment::JniLuaEnvironment(
const JniCache* jni_cache, const jobject context, StringPiece entity_text,
int64 event_time_ms_usec, int64 reference_time_ms_utc,
const std::map<std::string, Variant>& extra)
: jenv_(jni_cache ? jni_cache->GetEnv() : nullptr),
jni_cache_(jni_cache),
context_(context),
entity_text_(entity_text),
event_time_ms_usec_(event_time_ms_usec),
reference_time_ms_utc_(reference_time_ms_utc),
extra_(extra),
usermanager_(/*object=*/nullptr,
/*jvm=*/(jni_cache ? jni_cache->jvm : nullptr)),
usermanager_retrieved_(false) {
LoadDefaultLibraries();
// Setup callbacks.
// This exposes an `external` object with the following fields:
// * extras: the bundle with all information about a classification.
// * android: callbacks into specific android provided methods.
// * android.user_restrictions: callbacks to check user permissions.
// NOTE(review): the stack handling below assumes SetupTableLookupCallback
// leaves exactly one new table on top of the Lua stack — confirm in
// LuaEnvironment.
SetupTableLookupCallback("external", CALLBACK_ID_EXTERNAL);
// extras
lua_pushstring(state_, "extras");
SetupTableLookupCallback("extras", CALLBACK_ID_EXTRAS);
// lua_settable pops the key and the value and stores them in the table at
// index -3, i.e. external.extras = <extras table>.
lua_settable(state_, -3);
// android
lua_pushstring(state_, "android");
SetupTableLookupCallback("android", CALLBACK_ID_ANDROID);
// android.user_restrictions
lua_pushstring(state_, "user_restrictions");
SetupTableLookupCallback("user_restrictions", CALLBACK_ID_USER_PERMISSIONS);
// First store user_restrictions into android, then android into external.
lua_settable(state_, -3);
lua_settable(state_, -3);
// Pops the remaining table and publishes it as the global `external`.
lua_setglobal(state_, "external");
}
// Routes a callback raised from Lua to the handler registered for its id.
// Returns whatever the handler returns; an id that matches no CallbackId is
// logged and reported as LUA_ERRRUN.
int JniLuaEnvironment::HandleCallback(int callback_id) {
  if (callback_id == CALLBACK_ID_EXTERNAL) {
    return HandleExternalCallback();
  }
  if (callback_id == CALLBACK_ID_EXTRAS) {
    return HandleExtrasLookup();
  }
  if (callback_id == CALLBACK_ID_ANDROID) {
    return HandleAndroidCallback();
  }
  if (callback_id == CALLBACK_ID_USER_PERMISSIONS) {
    return HandleUserRestrictionsCallback();
  }
  if (callback_id == CALLBACK_ID_URL_ENCODE) {
    return HandleUrlEncode();
  }
  if (callback_id == CALLBACK_ID_URL_SCHEMA) {
    return HandleUrlSchema();
  }
  TC3_LOG(ERROR) << "Unhandled callback: " << callback_id;
  return LUA_ERRRUN;
}
int JniLuaEnvironment::HandleExternalCallback() {
const char* key = luaL_checkstring(state_, 2);
if (strcmp(kReferenceTimeUsecKey, key) == 0) {
lua_pushinteger(state_, reference_time_ms_utc_);
return LUA_YIELD;
} else {
TC3_LOG(ERROR) << "Undefined external access " << key;
return LUA_ERRRUN;
}
}
int JniLuaEnvironment::HandleExtrasLookup() {
const char* key = luaL_checkstring(state_, 2);
if (strcmp(kEntityTextKey, key) == 0) {
lua_pushlstring(state_, entity_text_.data(), entity_text_.length());
} else if (strcmp(kTimeUsecKey, key) == 0) {
lua_pushinteger(state_, event_time_ms_usec_);
} else {
const auto it = extra_.find(std::string(key));
if (it == extra_.end()) {
TC3_LOG(ERROR) << "Undefined extra lookup " << key;
return LUA_ERRRUN;
}
PushValue(it->second);
}
return LUA_YIELD;
}
int JniLuaEnvironment::HandleAndroidCallback() {
const char* key = luaL_checkstring(state_, 2);
if (strcmp("package_name", key) == 0) {
ScopedLocalRef<jstring> package_name_str(
static_cast<jstring>(jenv_->CallObjectMethod(
context_, jni_cache_->context_get_package_name)));
if (jni_cache_->ExceptionCheckAndClear()) {
TC3_LOG(ERROR) << "Error calling Context.getPackageName";
return LUA_ERRRUN;
}
ScopedStringChars package_name =
GetScopedStringChars(jenv_, package_name_str.get());
lua_pushstring(state_, reinterpret_cast<const char*>(package_name.get()));
return LUA_YIELD;
} else if (strcmp("urlencode", key) == 0) {
PushCallback(CALLBACK_ID_URL_ENCODE);
return LUA_YIELD;
} else if (strcmp("url_schema", key) == 0) {
PushCallback(CALLBACK_ID_URL_SCHEMA);
return LUA_YIELD;
} else {
TC3_LOG(ERROR) << "Undefined android reference " << key;
return LUA_ERRRUN;
}
}
// Serves lookups on `android.user_restrictions`.
// The second Lua argument names the restriction. Pushes a boolean taken from
// the bundle returned by UserManager.getUserRestrictions and returns
// LUA_YIELD. Pushes false when the user manager API is unavailable or the
// bundle lookup throws. Returns LUA_ERRRUN when the user manager or the
// restrictions bundle cannot be obtained.
int JniLuaEnvironment::HandleUserRestrictionsCallback() {
  if (jni_cache_ == nullptr || jni_cache_->usermanager_class == nullptr ||
      jni_cache_->usermanager_get_user_restrictions == nullptr) {
    // UserManager is only available for API level >= 17 and
    // getUserRestrictions only for API level >= 18, so we just return false
    // normally here. The same applies when there is no JNI cache at all.
    lua_pushboolean(state_, false);
    return LUA_YIELD;
  }

  // Get user manager if not previously retrieved.
  if (!RetrieveUserManager()) {
    TC3_LOG(ERROR) << "Error retrieving user manager.";
    return LUA_ERRRUN;
  }

  // Validate the argument before any scoped JNI reference is alive:
  // luaL_checkstring raises a Lua error on a bad argument, and that error
  // would skip the destructors of the scoped references below.
  const char* restriction = luaL_checkstring(state_, 2);

  ScopedLocalRef<jobject> bundle(jenv_->CallObjectMethod(
      usermanager_.get(), jni_cache_->usermanager_get_user_restrictions));
  if (jni_cache_->ExceptionCheckAndClear() || bundle == nullptr) {
    TC3_LOG(ERROR) << "Error calling getUserRestrictions";
    return LUA_ERRRUN;
  }

  ScopedLocalRef<jstring> key(jenv_->NewStringUTF(restriction));
  if (key == nullptr) {
    TC3_LOG(ERROR) << "Expected string, got null.";
    return LUA_ERRRUN;
  }

  const bool permission = jenv_->CallBooleanMethod(
      bundle.get(), jni_cache_->bundle_get_boolean, key.get());
  if (jni_cache_->ExceptionCheckAndClear()) {
    // Best effort: a failed lookup is reported as "not restricted".
    TC3_LOG(ERROR) << "Error getting bundle value";
    lua_pushboolean(state_, false);
  } else {
    lua_pushboolean(state_, permission);
  }
  return LUA_YIELD;
}
int JniLuaEnvironment::HandleUrlEncode() {
// Call Java URL encoder.
ScopedLocalRef<jstring> input_str(
jenv_->NewStringUTF(luaL_checkstring(state_, 1)));
if (input_str == nullptr) {
TC3_LOG(ERROR) << "Expected string, got null.";
return LUA_ERRRUN;
}
ScopedLocalRef<jstring> encoding_str(jenv_->NewStringUTF("UTF-8"));
ScopedLocalRef<jstring> encoded_str(
static_cast<jstring>(jenv_->CallStaticObjectMethod(
jni_cache_->urlencoder_class.get(), jni_cache_->urlencoder_encode,
input_str.get(), encoding_str.get())));
if (jni_cache_->ExceptionCheckAndClear()) {
TC3_LOG(ERROR) << "Error calling UrlEncoder.encode";
return LUA_ERRRUN;
}
ScopedStringChars encoded = GetScopedStringChars(jenv_, encoded_str.get());
lua_pushstring(state_, encoded.get());
return LUA_YIELD;
}
int JniLuaEnvironment::HandleUrlSchema() {
// Call to Java URI parser.
ScopedLocalRef<jstring> url_str(
jenv_->NewStringUTF(luaL_checkstring(state_, 1)));
if (url_str == nullptr) {
TC3_LOG(ERROR) << "Expected string, got null";
return LUA_ERRRUN;
}
// Try to parse uri and get scheme.
ScopedLocalRef<jobject> uri(jenv_->CallStaticObjectMethod(
jni_cache_->uri_class.get(), jni_cache_->uri_parse, url_str.get()));
if (jni_cache_->ExceptionCheckAndClear() || uri == nullptr) {
TC3_LOG(ERROR) << "Error calling Uri.parse";
return LUA_ERRRUN;
}
ScopedLocalRef<jstring> scheme_str(static_cast<jstring>(
jenv_->CallObjectMethod(uri.get(), jni_cache_->uri_get_scheme)));
if (jni_cache_->ExceptionCheckAndClear()) {
TC3_LOG(ERROR) << "Error calling Uri.getScheme";
return LUA_ERRRUN;
}
if (scheme_str == nullptr) {
lua_pushnil(state_);
} else {
ScopedStringChars scheme = GetScopedStringChars(jenv_, scheme_str.get());
lua_pushstring(state_, scheme.get());
}
return LUA_YIELD;
}
// Looks up the "user" system service on the context and keeps a global
// reference to it. The lookup is attempted at most once; later calls report
// the outcome of that first attempt.
// Returns true if a user manager reference is available.
bool JniLuaEnvironment::RetrieveUserManager() {
  if (context_ == nullptr) {
    return false;
  }
  if (usermanager_retrieved_) {
    return (usermanager_ != nullptr);
  }
  usermanager_retrieved_ = true;
  ScopedLocalRef<jstring> service(jenv_->NewStringUTF("user"));
  // Hold the returned local reference in a scoped wrapper so that it is
  // released once the global reference has been created.
  ScopedLocalRef<jobject> usermanager_ref(jenv_->CallObjectMethod(
      context_, jni_cache_->context_get_system_service, service.get()));
  if (jni_cache_->ExceptionCheckAndClear()) {
    TC3_LOG(ERROR) << "Error calling getSystemService.";
    return false;
  }
  if (usermanager_ref == nullptr) {
    // No such service; usermanager_ stays null.
    return false;
  }
  usermanager_ = MakeGlobalRef(usermanager_ref.get(), jenv_, jni_cache_->jvm);
  return (usermanager_ != nullptr);
}
// Reads one action template from the table on top of the Lua stack and pops
// that table.
// Entries with a non-string key and text fields whose value cannot be read as
// a string are logged and skipped; unknown keys are logged and ignored.
RemoteActionTemplate JniLuaEnvironment::ReadRemoteActionTemplateResult() {
  RemoteActionTemplate result;

  // Reads the value on top of the stack as text. Returns nullptr (after
  // logging) when the value is neither a string nor a number, so that a null
  // pointer is never assigned to a template field.
  const auto read_text = [this](const char* field) -> const char* {
    const char* text = lua_tostring(state_, /*idx=*/-1);
    if (text == nullptr) {
      TC3_LOG(ERROR) << "Expected string value for: " << field;
    }
    return text;
  };

  // Read intent template.
  lua_pushnil(state_);
  while (lua_next(state_, /*idx=*/-2)) {
    // Stack: ..., template table, key, value.
    // lua_tostring would return null for most non-string keys and would
    // convert a numeric key in place, which confuses lua_next, so only
    // genuine string keys are looked at.
    if (lua_type(state_, /*idx=*/-2) != LUA_TSTRING) {
      TC3_LOG(ERROR) << "Expected string key in intent table, got: "
                     << lua_type(state_, /*idx=*/-2);
      lua_pop(state_, 1);
      continue;
    }
    const char* key = lua_tostring(state_, /*idx=*/-2);
    if (strcmp("title", key) == 0) {
      if (const char* text = read_text(key)) result.title = text;
    } else if (strcmp("description", key) == 0) {
      if (const char* text = read_text(key)) result.description = text;
    } else if (strcmp("action", key) == 0) {
      if (const char* text = read_text(key)) result.action = text;
    } else if (strcmp("data", key) == 0) {
      if (const char* text = read_text(key)) result.data = text;
    } else if (strcmp("type", key) == 0) {
      if (const char* text = read_text(key)) result.type = text;
    } else if (strcmp("flags", key) == 0) {
      result.flags = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
    } else if (strcmp("package_name", key) == 0) {
      if (const char* text = read_text(key)) result.package_name = text;
    } else if (strcmp("request_code", key) == 0) {
      result.request_code = static_cast<int>(lua_tointeger(state_, /*idx=*/-1));
    } else if (strcmp("category", key) == 0) {
      ReadCategories(&result.category);
    } else if (strcmp("extra", key) == 0) {
      ReadExtras(&result.extra);
    } else {
      TC3_LOG(INFO) << "Unknown entry: " << key;
    }
    // Drop the value, keep the key for the next lua_next call.
    lua_pop(state_, 1);
  }
  // Drop the template table itself.
  lua_pop(state_, 1);
  return result;
}
void JniLuaEnvironment::ReadCategories(std::vector<std::string>* category) {
// Read category array.
if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected categories table, got: "
<< lua_type(state_, /*idx=*/-1);
lua_pop(state_, 1);
return;
}
lua_pushnil(state_);
while (lua_next(state_, /*idx=*/-2)) {
category->push_back(lua_tostring(state_, /*idx=*/-1));
lua_pop(state_, 1);
}
}
void JniLuaEnvironment::ReadExtras(std::map<std::string, Variant>* extra) {
if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected extras table, got: "
<< lua_type(state_, /*idx=*/-1);
lua_pop(state_, 1);
return;
}
lua_pushnil(state_);
while (lua_next(state_, /*idx=*/-2)) {
// Each entry is a table specifying name and value.
// The value is specified via a type specific field as Lua doesn't allow
// to easily distinguish between different number types.
if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected a table for an extra, got: "
<< lua_type(state_, /*idx=*/-1);
lua_pop(state_, 1);
return;
}
std::string name;
Variant value;
lua_pushnil(state_);
while (lua_next(state_, /*idx=*/-2)) {
const char* key = lua_tostring(state_, /*idx=*/-2);
if (strcmp("name", key) == 0) {
name = std::string(lua_tostring(state_, /*idx=*/-1));
} else if (strcmp("int_value", key) == 0) {
value = Variant(static_cast<int>(lua_tonumber(state_, /*idx=*/-1)));
} else if (strcmp("long_value", key) == 0) {
value = Variant(static_cast<int64>(lua_tonumber(state_, /*idx=*/-1)));
} else if (strcmp("float_value", key) == 0) {
value = Variant(static_cast<float>(lua_tonumber(state_, /*idx=*/-1)));
} else if (strcmp("bool_value", key) == 0) {
value = Variant(static_cast<bool>(lua_toboolean(state_, /*idx=*/-1)));
} else if (strcmp("string_value", key) == 0) {
value = Variant(lua_tostring(state_, /*idx=*/-1));
} else {
TC3_LOG(INFO) << "Unknown extra field: " << key;
}
lua_pop(state_, 1);
}
if (!name.empty()) {
(*extra)[name] = value;
} else {
TC3_LOG(ERROR) << "Unnamed extra entry. Skipping.";
}
lua_pop(state_, 1);
}
}
} // namespace
std::vector<RemoteActionTemplate> JniLuaEnvironment::RunIntentGenerator(
const std::string& generator_snippet) {
int status = luaL_loadstring(state_, generator_snippet.data());
if (status != LUA_OK) {
TC3_LOG(ERROR) << "Couldn't load generator snippet: " << status;
return {};
}
status = lua_pcall(state_, /*nargs=*/0, /*nargs=*/1, /*errfunc=*/0);
if (status != LUA_OK) {
TC3_LOG(ERROR) << "Couldn't run generator snippet: " << status;
return {};
}
// Read result.
if (lua_gettop(state_) != 1 || lua_type(state_, 1) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Unexpected result for snippet.";
return {};
}
// Read remote action templates array.
std::vector<RemoteActionTemplate> result;
lua_pushnil(state_);
while (lua_next(state_, /*idx=*/-2)) {
if (lua_type(state_, /*idx=*/-1) != LUA_TTABLE) {
TC3_LOG(ERROR) << "Expected intent table, got: "
<< lua_type(state_, /*idx=*/-1);
lua_pop(state_, 1);
continue;
}
result.push_back(ReadRemoteActionTemplateResult());
}
lua_pop(state_, /*n=*/1);
// Check that we correctly cleaned-up the state.
const int stack_size = lua_gettop(state_);
if (stack_size > 0) {
TC3_LOG(ERROR) << "Unexpected stack size.";
lua_settop(state_, 0);
return {};
}
return result;
}
// Collects the Lua generator snippet of every entity type listed in `options`.
// No generators are registered when `options` has no entities, or when a
// non-null `context` cannot be verified to be an android.content.Context.
// Entries lacking an entity type or a snippet are logged and skipped.
IntentGenerator::IntentGenerator(const IntentFactoryModel* options,
                                 const std::shared_ptr<JniCache>& jni_cache,
                                 const jobject context)
    : options_(options), jni_cache_(jni_cache), context_(context) {
  if (options_ == nullptr || options_->entities() == nullptr) {
    return;
  }

  // Normally this check would be performed by the Java compiler and we wouldn't
  // need to worry about it here. But we can't depend on Android's SDK in Java,
  // so we check the instance type here.
  if (context != nullptr) {
    // Without a JNI cache and environment the context cannot be checked, and
    // dereferencing a null cache must be avoided.
    JNIEnv* env = (jni_cache != nullptr) ? jni_cache->GetEnv() : nullptr;
    if (env == nullptr) {
      TC3_LOG(ERROR) << "No JNI environment to verify the provided context.";
      return;
    }
    if (!env->IsInstanceOf(context, jni_cache->context_class.get())) {
      TC3_LOG(ERROR) << "Provided context is not an android.content.Context";
      return;
    }
  }

  for (const IntentFactoryModel_::IntentGenerator* generator :
       *options_->entities()) {
    if (generator == nullptr || generator->entity_type() == nullptr ||
        generator->lua_template_generator() == nullptr) {
      TC3_LOG(ERROR) << "Incomplete intent generator entry. Skipping.";
      continue;
    }
    generators_[generator->entity_type()->str()] =
        std::string(reinterpret_cast<const char*>(
                        generator->lua_template_generator()->data()),
                    generator->lua_template_generator()->size());
  }
}
std::vector<RemoteActionTemplate> IntentGenerator::GenerateIntents(
const ClassificationResult& classification, int64 reference_time_ms_usec,
StringPiece entity_text) const {
if (options_ == nullptr) {
return {};
}
// Retrieve generator for specified entity.
auto it = generators_.find(classification.collection);
if (it == generators_.end()) {
TC3_LOG(INFO) << "Unknown entity: " << classification.collection;
return {};
}
std::unique_ptr<JniLuaEnvironment> interpreter(
new JniLuaEnvironment(jni_cache_.get(), context_, entity_text,
classification.datetime_parse_result.time_ms_utc,
reference_time_ms_usec, classification.extra));
return interpreter->RunIntentGenerator(it->second);
}
} // namespace libtextclassifier3