blob: f064a63dbfcab70b3ec30b60ccbb5028774f1987 [file] [log] [blame]
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#include "cached-features.h"
18
19#include "model-executor.h"
20#include "tensor-view.h"
21
22#include "gmock/gmock.h"
23#include "gtest/gtest.h"
24
25using testing::ElementsAreArray;
26using testing::FloatEq;
27using testing::Matcher;
28
29namespace libtextclassifier2 {
30namespace {
31
32Matcher<std::vector<float>> ElementsAreFloat(const std::vector<float>& values) {
33 std::vector<Matcher<float>> matchers;
34 for (const float value : values) {
35 matchers.push_back(FloatEq(value));
36 }
37 return ElementsAreArray(matchers);
38}
39
// Creates a flat feature buffer for `num_tokens` synthetic tokens.
//
// Token i (1-based) contributes three consecutive floats:
//   {i * 11.0f, -i * 11.0f, i * 0.1f}
// so the returned vector has 3 * num_tokens elements. A non-positive
// `num_tokens` yields an empty (but non-null) vector.
std::unique_ptr<std::vector<float>> MakeFeatures(int num_tokens) {
  std::unique_ptr<std::vector<float>> features(new std::vector<float>());
  if (num_tokens > 0) {
    // Three floats per token; reserve once instead of growing incrementally.
    features->reserve(
        static_cast<std::vector<float>::size_type>(num_tokens) * 3);
  }
  for (int i = 1; i <= num_tokens; ++i) {
    features->push_back(i * 11.0f);
    features->push_back(-i * 11.0f);
    features->push_back(i * 0.1f);
  }
  return features;
}
Lukas Zilka21d8c982018-01-24 11:11:20 +010049
Lukas Zilkab23e2122018-02-09 10:25:19 +010050std::vector<float> GetCachedClickContextFeatures(
51 const CachedFeatures& cached_features, int click_pos) {
52 std::vector<float> output_features;
53 cached_features.AppendClickContextFeaturesForClick(click_pos,
54 &output_features);
55 return output_features;
56}
57
58std::vector<float> GetCachedBoundsSensitiveFeatures(
59 const CachedFeatures& cached_features, TokenSpan selected_span) {
60 std::vector<float> output_features;
61 cached_features.AppendBoundsSensitiveFeaturesForSpan(selected_span,
62 &output_features);
63 return output_features;
64}
65
66TEST(CachedFeaturesTest, ClickContext) {
67 FeatureProcessorOptionsT options;
68 options.context_size = 2;
69 options.feature_version = 1;
Lukas Zilka21d8c982018-01-24 11:11:20 +010070 flatbuffers::FlatBufferBuilder builder;
Lukas Zilkab23e2122018-02-09 10:25:19 +010071 builder.Finish(CreateFeatureProcessorOptions(builder, &options));
72 flatbuffers::DetachedBuffer options_fb = builder.Release();
Lukas Zilka21d8c982018-01-24 11:11:20 +010073
Lukas Zilkaba849e72018-03-08 14:48:21 +010074 std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
75 std::unique_ptr<std::vector<float>> padding_features(
76 new std::vector<float>{112233.0, -112233.0, 321.0});
Lukas Zilka21d8c982018-01-24 11:11:20 +010077
Lukas Zilkab23e2122018-02-09 10:25:19 +010078 const std::unique_ptr<CachedFeatures> cached_features =
79 CachedFeatures::Create(
Lukas Zilkaba849e72018-03-08 14:48:21 +010080 {3, 10}, std::move(features), std::move(padding_features),
Lukas Zilkab23e2122018-02-09 10:25:19 +010081 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
Lukas Zilkaba849e72018-03-08 14:48:21 +010082 /*feature_vector_size=*/3);
Lukas Zilkab23e2122018-02-09 10:25:19 +010083 ASSERT_TRUE(cached_features);
Lukas Zilka21d8c982018-01-24 11:11:20 +010084
Lukas Zilkab23e2122018-02-09 10:25:19 +010085 EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 5),
86 ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0, -33.0,
87 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5}));
88
89 EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 6),
90 ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0, -44.0,
91 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6}));
92
93 EXPECT_THAT(GetCachedClickContextFeatures(*cached_features, 7),
94 ElementsAreFloat({33.0, -33.0, 0.3, 44.0, -44.0, 0.4, 55.0, -55.0,
95 0.5, 66.0, -66.0, 0.6, 77.0, -77.0, 0.7}));
96}
97
98TEST(CachedFeaturesTest, BoundsSensitive) {
99 std::unique_ptr<FeatureProcessorOptions_::BoundsSensitiveFeaturesT> config(
100 new FeatureProcessorOptions_::BoundsSensitiveFeaturesT());
101 config->enabled = true;
102 config->num_tokens_before = 2;
103 config->num_tokens_inside_left = 2;
104 config->num_tokens_inside_right = 2;
105 config->num_tokens_after = 2;
106 config->include_inside_bag = true;
107 config->include_inside_length = true;
108 FeatureProcessorOptionsT options;
109 options.bounds_sensitive_features = std::move(config);
110 options.feature_version = 2;
111 flatbuffers::FlatBufferBuilder builder;
112 builder.Finish(CreateFeatureProcessorOptions(builder, &options));
113 flatbuffers::DetachedBuffer options_fb = builder.Release();
114
Lukas Zilkaba849e72018-03-08 14:48:21 +0100115 std::unique_ptr<std::vector<float>> features = MakeFeatures(9);
116 std::unique_ptr<std::vector<float>> padding_features(
117 new std::vector<float>{112233.0, -112233.0, 321.0});
Lukas Zilkab23e2122018-02-09 10:25:19 +0100118
Lukas Zilkab23e2122018-02-09 10:25:19 +0100119 const std::unique_ptr<CachedFeatures> cached_features =
120 CachedFeatures::Create(
Lukas Zilkaba849e72018-03-08 14:48:21 +0100121 {3, 9}, std::move(features), std::move(padding_features),
Lukas Zilkab23e2122018-02-09 10:25:19 +0100122 flatbuffers::GetRoot<FeatureProcessorOptions>(options_fb.data()),
Lukas Zilkaba849e72018-03-08 14:48:21 +0100123 /*feature_vector_size=*/3);
Lukas Zilkab23e2122018-02-09 10:25:19 +0100124 ASSERT_TRUE(cached_features);
Lukas Zilka21d8c982018-01-24 11:11:20 +0100125
126 EXPECT_THAT(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100127 GetCachedBoundsSensitiveFeatures(*cached_features, {5, 8}),
128 ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
129 -33.0, 0.3, 44.0, -44.0, 0.4, 44.0, -44.0,
130 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
131 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4, 3.0}));
132
133 EXPECT_THAT(
134 GetCachedBoundsSensitiveFeatures(*cached_features, {5, 7}),
Lukas Zilka21d8c982018-01-24 11:11:20 +0100135 ElementsAreFloat({11.0, -11.0, 0.1, 22.0, -22.0, 0.2, 33.0,
136 -33.0, 0.3, 44.0, -44.0, 0.4, 33.0, -33.0,
137 0.3, 44.0, -44.0, 0.4, 55.0, -55.0, 0.5,
Lukas Zilkab23e2122018-02-09 10:25:19 +0100138 66.0, -66.0, 0.6, 38.5, -38.5, 0.35, 2.0}));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100139
140 EXPECT_THAT(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100141 GetCachedBoundsSensitiveFeatures(*cached_features, {6, 8}),
Lukas Zilka21d8c982018-01-24 11:11:20 +0100142 ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3, 44.0,
143 -44.0, 0.4, 55.0, -55.0, 0.5, 44.0, -44.0,
144 0.4, 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
Lukas Zilkab23e2122018-02-09 10:25:19 +0100145 112233.0, -112233.0, 321.0, 49.5, -49.5, 0.45, 2.0}));
Lukas Zilka21d8c982018-01-24 11:11:20 +0100146
147 EXPECT_THAT(
Lukas Zilkab23e2122018-02-09 10:25:19 +0100148 GetCachedBoundsSensitiveFeatures(*cached_features, {6, 7}),
Lukas Zilka21d8c982018-01-24 11:11:20 +0100149 ElementsAreFloat({22.0, -22.0, 0.2, 33.0, -33.0, 0.3,
150 44.0, -44.0, 0.4, 112233.0, -112233.0, 321.0,
151 112233.0, -112233.0, 321.0, 44.0, -44.0, 0.4,
152 55.0, -55.0, 0.5, 66.0, -66.0, 0.6,
153 44.0, -44.0, 0.4, 1.0}));
154}
155
156} // namespace
157} // namespace libtextclassifier2