Sync from google3.
Bug: 68239358
Test: Builds. Tested on device. CTS test passes.
      bit FrameworksCoreTests:android.view.textclassifier.TextClassificationManagerTest
Change-Id: Ie5e20b06b1c615ab246e7ed7f08e980e61c492c4
diff --git a/cached-features_test.cc b/cached-features_test.cc
index 2412ff3..9566a8d 100644
--- a/cached-features_test.cc
+++ b/cached-features_test.cc
@@ -54,18 +54,29 @@
std::vector<float> storage_;
};
-TEST(CachedFeaturesTest, Simple) {
- FeatureProcessorOptions_::BoundsSensitiveFeaturesT config;
- config.enabled = true;
- config.num_tokens_before = 2;
- config.num_tokens_inside_left = 2;
- config.num_tokens_inside_right = 2;
- config.num_tokens_after = 2;
- config.include_inside_bag = true;
- config.include_inside_length = true;
+std::vector<float> GetCachedClickContextFeatures(  // Test helper: collects the click-context features for click_pos into a new vector.
+ const CachedFeatures& cached_features, int click_pos) {
+ std::vector<float> output_features;  // Starts empty, so the returned vector holds only what the call below appends.
+ cached_features.AppendClickContextFeaturesForClick(click_pos,
+ &output_features);
+ return output_features;
+}
+
+std::vector<float> GetCachedBoundsSensitiveFeatures(  // Test helper: collects the bounds-sensitive features for selected_span into a new vector.
+ const CachedFeatures& cached_features, TokenSpan selected_span) {
+ std::vector<float> output_features;  // Starts empty, so the returned vector holds only what the call below appends.
+ cached_features.AppendBoundsSensitiveFeaturesForSpan(selected_span,
+ &output_features);
+ return output_features;
+}
+
+TEST(CachedFeaturesTest, ClickContext) {
+ FeatureProcessorOptionsT options;
+ options.context_size = 2;
+ options.feature_version = 1;
flatbuffers::FlatBufferBuilder builder;
- builder.Finish(CreateBoundsSensitiveFeatures(builder, &config));
- flatbuffers::DetachedBuffer config_fb = builder.Release();
+ builder.Finish(CreateFeatureProcessorOptions(builder, &options));
+ flatbuffers::DetachedBuffer options_fb = builder.Release();
std::vector<std::vector<int>> sparse_features(9);
for (int i = 0; i < sparse_features.size(); ++i) {
@@ -80,36 +91,88 @@
std::vector<float> padding_dense_features = {321.0};
FakeEmbeddingExecutor executor;
- const CachedFeatures cached_features(
- {3, 9}, sparse_features, dense_features, padding_sparse_features,
- padding_dense_features,
- flatbuffers::GetRoot<FeatureProcessorOptions_::BoundsSensitiveFeatures>(
- config_fb.data()),
- &executor, /*feature_vector_size=*/3);
+ const std::unique_ptr<CachedFeatures> cached_features =
+ CachedFeatures::Create(
+ {3, 10}, sparse_features, dense_features, padding_sparse_features,
+ padding_dense_features,
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &executor, /*feature_vector_size=*/3);
+ ASSERT_TRUE(cached_features);
- EXPECT_THAT(cached_features.Get({5, 8}),
- ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2,
- 33.0, -33.0, 0.3, 44.0, -44.0, 0.4,
- 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
- 66.0, -66.0, 0.6, 112233.0, -112233.0, 321.0,
- 132.0, -132.0, 1.2, 3.0}));
+ EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5),
+ ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0, -33.0,
+ 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5}));
+
+ EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 6),
+ ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0, -44.0,
+ 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6}));
+
+ EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 7),
+ ElementsAreFloat({33.0, -33.0, 0.3, 44.0, -44.0, 0.4, 55.0, -55.0,
+ 0.5, 66.0, -66.0, 0.6, 77.0, -77.0, 0.7}));
+}
+
+TEST(CachedFeaturesTest, BoundsSensitive) {
+ std::unique_ptr<FeatureProcessorOptions_::BoundsSensitiveFeaturesT> config(
+ new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
+ config->enabled = true;
+ config->num_tokens_before = 2;
+ config->num_tokens_inside_left = 2;
+ config->num_tokens_inside_right = 2;
+ config->num_tokens_after = 2;
+ config->include_inside_bag = true;
+ config->include_inside_length = true;
+ FeatureProcessorOptionsT options;
+ options.bounds_sensitive_features = std::move(config);
+ options.feature_version = 2;
+ flatbuffers::FlatBufferBuilder builder;
+ builder.Finish(CreateFeatureProcessorOptions(builder, &options));
+ flatbuffers::DetachedBuffer options_fb = builder.Release();
+
+ std::vector<std::vector<int>> sparse_features(6);
+ for (int i = 0; i < sparse_features.size(); ++i) {
+ sparse_features[i].push_back(i + 1);
+ }
+ std::vector<std::vector<float>> dense_features(6);
+ for (int i = 0; i < dense_features.size(); ++i) {
+ dense_features[i].push_back((i + 1) * 0.1);
+ }
+
+ std::vector<int> padding_sparse_features = {10203};
+ std::vector<float> padding_dense_features = {321.0};
+
+ FakeEmbeddingExecutor executor;
+ const std::unique_ptr<CachedFeatures> cached_features =
+ CachedFeatures::Create(
+ {3, 9}, sparse_features, dense_features, padding_sparse_features,
+ padding_dense_features,
+ flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
+ &executor, /*feature_vector_size=*/3);
+ ASSERT_TRUE(cached_features);
EXPECT_THAT(
- cached_features.Get({5, 7}),
+ GetCachedBoundsSensitiveFeatures(*cached_features, {5, 8}),
+ ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
+ -33.0, 0.3, 44.0, -44.0, 0.4, 44.0, -44.0,
+ 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
+ 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4, 3.0}));
+
+ EXPECT_THAT(
+ GetCachedBoundsSensitiveFeatures(*cached_features, {5, 7}),
ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
-33.0, 0.3, 44.0, -44.0, 0.4, 33.0, -33.0,
0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
- 66.0, -66.0, 0.6, 77.0, -77.0, 0.7, 2.0}));
+ 66.0, -66.0, 0.6, 38.5, -38.5, 0.35, 2.0}));
EXPECT_THAT(
- cached_features.Get({6, 8}),
+ GetCachedBoundsSensitiveFeatures(*cached_features, {6, 8}),
ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0,
-44.0, 0.4, 55.0, -55.0, 0.5, 44.0, -44.0,
0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
- 112233.0, -112233.0, 321.0, 99.0, -99.0, 0.9, 2.0}));
+ 112233.0, -112233.0, 321.0, 49.5, -49.5, 0.45, 2.0}));
EXPECT_THAT(
- cached_features.Get({6, 7}),
+ GetCachedBoundsSensitiveFeatures(*cached_features, {6, 7}),
ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3,
44.0, -44.0, 0.4, 112233.0, -112233.0, 321.0,
112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4,