Export libtextclassifier am: c121edde42 am: 55ed12d074

Change-Id: Ie7d53265c9b0b005a5f835de8102c3b4d3c73831
diff --git a/native/actions/actions-suggestions.cc b/native/actions/actions-suggestions.cc
index a84f2cd..1fcd35c 100644
--- a/native/actions/actions-suggestions.cc
+++ b/native/actions/actions-suggestions.cc
@@ -22,6 +22,7 @@
 #include "actions/types.h"
 #include "actions/utils.h"
 #include "actions/zlib-utils.h"
+#include "annotator/collections.h"
 #include "utils/base/logging.h"
 #include "utils/flatbuffers.h"
 #include "utils/lua-utils.h"
@@ -50,6 +51,11 @@
     *[]() { return new std::string("send_email"); }();
 const std::string& ActionsSuggestions::kShareLocation =
     *[]() { return new std::string("share_location"); }();
+
+// File-local name for a datetime annotation that has time but no date.
+static const std::string& kTimeAnnotation =
+    *[]() { return new std::string("time"); }();
+
 constexpr float kDefaultFloat = 0.0;
 constexpr bool kDefaultBool = false;
 constexpr int kDefaultInt = 1;
@@ -260,6 +266,7 @@
     }
   }
 
+  // Gather annotation entities for the rules.
   if (model_->annotation_actions_spec() != nullptr &&
       model_->annotation_actions_spec()->annotation_mapping() != nullptr) {
     for (const AnnotationActionsSpec_::AnnotationMapping* mapping :
@@ -300,6 +307,18 @@
     grammar_actions_.reset(new GrammarActions(
         unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(),
         model_->smart_reply_action_type()->str()));
+
+    // Gather annotation entities for the grammars. Nested flatbuffer tables
+    // may be absent (their accessors then return nullptr), so guard each one.
+    const auto* grammar_rules_set = model_->rules()->grammar_rules()->rules();
+    if (grammar_rules_set != nullptr &&
+        grammar_rules_set->nonterminals() != nullptr &&
+        grammar_rules_set->nonterminals()->annotation_nt() != nullptr) {
+      for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry :
+           *grammar_rules_set->nonterminals()->annotation_nt()) {
+        annotation_entity_types_.insert(entry->key()->str());
+      }
+    }
   }
 
   std::string actions_script;
@@ -689,47 +708,41 @@
           interpreter->tensor(interpreter->inputs()[param_index])->type;
       const auto param_value_it = model_parameters.find(param_name);
       const bool has_value = param_value_it != model_parameters.end();
-      /*
-      case kTfLiteInt16:
-        *tflite::GetTensorData<int16_t>(input_tensor) = input_value;
-        break;
-      case kTfLiteInt8:
-       */
       switch (param_type) {
         case kTfLiteFloat32:
           model_executor_->SetInput<float>(
               param_index,
-              has_value ? param_value_it->second.FloatValue() : kDefaultFloat,
+              has_value ? param_value_it->second.Value<float>() : kDefaultFloat,
               interpreter);
           break;
         case kTfLiteInt32:
           model_executor_->SetInput<int32_t>(
               param_index,
-              has_value ? param_value_it->second.IntValue() : kDefaultInt,
+              has_value ? param_value_it->second.Value<int>() : kDefaultInt,
               interpreter);
           break;
         case kTfLiteInt64:
           model_executor_->SetInput<int64_t>(
               param_index,
-              has_value ? param_value_it->second.Int64Value() : kDefaultInt,
+              has_value ? param_value_it->second.Value<int64>() : kDefaultInt,
               interpreter);
           break;
         case kTfLiteUInt8:
           model_executor_->SetInput<uint8_t>(
               param_index,
-              has_value ? param_value_it->second.UInt8Value() : kDefaultInt,
+              has_value ? param_value_it->second.Value<uint8>() : kDefaultInt,
               interpreter);
           break;
         case kTfLiteInt8:
           model_executor_->SetInput<int8_t>(
               param_index,
-              has_value ? param_value_it->second.Int8Value() : kDefaultInt,
+              has_value ? param_value_it->second.Value<int8>() : kDefaultInt,
               interpreter);
           break;
         case kTfLiteBool:
           model_executor_->SetInput<bool>(
               param_index,
-              has_value ? param_value_it->second.BoolValue() : kDefaultBool,
+              has_value ? param_value_it->second.Value<bool>() : kDefaultBool,
               interpreter);
           break;
         default:
@@ -1023,6 +1036,30 @@
     if (message->annotations.empty()) {
       message->annotations = annotator->Annotate(
           message->text, AnnotationOptionsForMessage(*message));
+      for (auto& annotation : message->annotations) {
+        // front() below is undefined on an empty vector; nothing to refine.
+        if (annotation.classification.empty()) continue;
+        ClassificationResult* classification =
+            &annotation.classification.front();
+
+        // Specialize datetime annotation to time annotation if no date
+        // component (any type ordered before HOUR) is present.
+        if (classification->collection == Collections::DateTime() &&
+            classification->datetime_parse_result.IsSet()) {
+          bool has_only_time = true;
+          for (const DatetimeComponent& component :
+               classification->datetime_parse_result.datetime_components) {
+            if (component.component_type !=
+                    DatetimeComponent::ComponentType::UNSPECIFIED &&
+                component.component_type <
+                    DatetimeComponent::ComponentType::HOUR) {
+              has_only_time = false;
+              break;
+            }
+          }
+          if (has_only_time) classification->collection = kTimeAnnotation;
+        }
+      }
     }
   }
   return annotated_conversation;
@@ -1224,6 +1261,13 @@
 
   SuggestActionsFromAnnotations(annotated_conversation, &response->actions);
 
+  if (grammar_actions_ != nullptr &&
+      !grammar_actions_->SuggestActions(annotated_conversation,
+                                        &response->actions)) {
+    TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
+    return false;
+  }
+
   int input_text_length = 0;
   int num_matching_locales = 0;
   for (int i = annotated_conversation.messages.size() - num_messages;
@@ -1299,13 +1343,6 @@
     return false;
   }
 
-  if (grammar_actions_ != nullptr &&
-      !grammar_actions_->SuggestActions(annotated_conversation,
-                                        &response->actions)) {
-    TC3_LOG(ERROR) << "Could not suggest actions from grammar rules.";
-    return false;
-  }
-
   if (preconditions_.suppress_on_low_confidence_input &&
       !regex_actions_->FilterConfidenceOutput(post_check_rules,
                                               &response->actions)) {
diff --git a/native/annotator/annotator.cc b/native/annotator/annotator.cc
index 4b11c7e..6ee983f 100644
--- a/native/annotator/annotator.cc
+++ b/native/annotator/annotator.cc
@@ -699,8 +699,8 @@
 
 bool Annotator::InitializeExperimentalAnnotators() {
   if (ExperimentalAnnotator::IsEnabled()) {
-    experimental_annotator_.reset(
-        new ExperimentalAnnotator(*selection_feature_processor_, *unilib_));
+    experimental_annotator_.reset(new ExperimentalAnnotator(
+        model_->experimental_model(), *selection_feature_processor_, *unilib_));
     return true;
   }
   return false;
@@ -2496,13 +2496,22 @@
       LoadAndVerifyMutableFlatbuffer<libtextclassifier3::EntityData>(
           *serialized_entity_data);
   if (data == nullptr) {
-    TC3_LOG(ERROR)
-        << "Data field is null when trying to parse Money Entity Data";
+    if (model_->version() >= 706) {
+      // This way of parsing money entity data is enabled for models v706 and
+      // newer, consequently logging errors only for them (b/156634162).
+      TC3_LOG(ERROR)
+          << "Data field is null when trying to parse Money Entity Data";
+    }
     return false;
   }
   if (data->money->unnormalized_amount.empty()) {
-    TC3_LOG(ERROR) << "Data unnormalized_amount is empty when trying to parse "
-                      "Money Entity Data";
+    if (model_->version() >= 706) {
+      // This way of parsing money entity data is enabled for models v706 and
+      // newer, consequently logging errors only for them (b/156634162).
+      TC3_LOG(ERROR)
+          << "Data unnormalized_amount is empty when trying to parse "
+             "Money Entity Data";
+    }
     return false;
   }
 
@@ -2593,7 +2602,11 @@
         if (regex_pattern.config->collection_name()->str() ==
             Collections::Money()) {
           if (!ParseAndFillInMoneyAmount(&serialized_entity_data)) {
-            TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
+            if (model_->version() >= 706) {
+              // This way of parsing money entity data is enabled for models
+              // v706 and newer => logging errors only for them (b/156634162).
+              TC3_LOG(ERROR) << "Could not parse and fill in money amount.";
+            }
           }
         }
       }
diff --git a/native/annotator/datetime/extractor.cc b/native/annotator/datetime/extractor.cc
index ebcf091..b8e1b7a 100644
--- a/native/annotator/datetime/extractor.cc
+++ b/native/annotator/datetime/extractor.cc
@@ -473,6 +473,7 @@
                       {DatetimeExtractorType_NEXT, 1},
                       {DatetimeExtractorType_NEXT_OR_SAME, 1},
                       {DatetimeExtractorType_LAST, -1},
+                      {DatetimeExtractorType_PAST, -1},
                   },
                   relative_count);
 }
diff --git a/native/annotator/experimental/experimental-dummy.h b/native/annotator/experimental/experimental-dummy.h
index 0d50bca..389aae1 100644
--- a/native/annotator/experimental/experimental-dummy.h
+++ b/native/annotator/experimental/experimental-dummy.h
@@ -33,7 +33,8 @@
   // always disabled;
   static constexpr bool IsEnabled() { return false; }
 
-  explicit ExperimentalAnnotator(const FeatureProcessor& feature_processor,
+  explicit ExperimentalAnnotator(const ExperimentalModel* model,
+                                 const FeatureProcessor& feature_processor,
                                  const UniLib& unilib) {}
 
   bool Annotate(const UnicodeText& context,
diff --git a/native/annotator/experimental/experimental.fbs b/native/annotator/experimental/experimental.fbs
index fff2d9e..6e15d04 100755
--- a/native/annotator/experimental/experimental.fbs
+++ b/native/annotator/experimental/experimental.fbs
@@ -1,3 +1,19 @@
+//
+// Copyright (C) 2018 The Android Open Source Project
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+
 namespace libtextclassifier3;
 table ExperimentalModel {
 }
diff --git a/native/annotator/grammar/dates/utils/date-match.cc b/native/annotator/grammar/dates/utils/date-match.cc
index 1ab1e6a..d9fca52 100644
--- a/native/annotator/grammar/dates/utils/date-match.cc
+++ b/native/annotator/grammar/dates/utils/date-match.cc
@@ -225,6 +225,18 @@
   return DatetimeComponent::RelativeQualifier::UNSPECIFIED;
 }
 
+// Folds the relative qualifier into the sign of the relative counter: the
+// counter is negated when the qualifier is RelativeQualifier::PAST and is
+// returned unchanged for every other qualifier.
+int GetAdjustedRelativeCounter(
+    const DatetimeComponent::RelativeQualifier& relative_qualifier,
+    const int relative_counter) {
+  const bool is_past =
+      relative_qualifier == DatetimeComponent::RelativeQualifier::PAST;
+  // PAST counts backwards in time, hence the negation.
+  return is_past ? -relative_counter : relative_counter;
+}
+
 Optional<DatetimeComponent> CreateDatetimeComponent(
     const DatetimeComponent::ComponentType& component_type,
     const DatetimeComponent::RelativeQualifier& relative_qualifier,
@@ -232,13 +244,15 @@
   if (absolute_value == NO_VAL && relative_value == NO_VAL) {
     return Optional<DatetimeComponent>();
   }
-  return Optional<DatetimeComponent>(
-      DatetimeComponent(component_type,
-                        (relative_value != NO_VAL)
-                            ? relative_qualifier
-                            : DatetimeComponent::RelativeQualifier::UNSPECIFIED,
-                        (absolute_value != NO_VAL) ? absolute_value : 0,
-                        (relative_value != NO_VAL) ? relative_value : 0));
+  return Optional<DatetimeComponent>(DatetimeComponent(
+      component_type,
+      (relative_value != NO_VAL)
+          ? relative_qualifier
+          : DatetimeComponent::RelativeQualifier::UNSPECIFIED,
+      (absolute_value != NO_VAL) ? absolute_value : 0,
+      (relative_value != NO_VAL)
+          ? GetAdjustedRelativeCounter(relative_qualifier, relative_value)
+          : 0));
 }
 
 Optional<DatetimeComponent> CreateDayOfWeekComponent(
diff --git a/native/utils/calendar/calendar-common.h b/native/utils/calendar/calendar-common.h
index f842300..e6fd076 100644
--- a/native/utils/calendar/calendar-common.h
+++ b/native/utils/calendar/calendar-common.h
@@ -229,7 +229,7 @@
     case DatetimeComponent::RelativeQualifier::PAST:
       TC3_CALENDAR_CHECK(
           AdjustByRelation(relative_date_time_component,
-                           -relative_date_time_component.relative_count,
+                           relative_date_time_component.relative_count,
                            /*allow_today=*/false, calendar))
       return true;
     case DatetimeComponent::RelativeQualifier::FUTURE:
diff --git a/native/utils/flatbuffers.cc b/native/utils/flatbuffers.cc
index 1cf60a9..cf4c97f 100644
--- a/native/utils/flatbuffers.cc
+++ b/native/utils/flatbuffers.cc
@@ -24,49 +24,6 @@
 
 namespace libtextclassifier3 {
 namespace {
-bool CreateRepeatedField(const reflection::Schema* schema,
-                         const reflection::Type* type,
-                         std::unique_ptr<RepeatedField>* repeated_field) {
-  switch (type->element()) {
-    case reflection::Bool:
-      repeated_field->reset(new TypedRepeatedField<bool>);
-      return true;
-    case reflection::Byte:
-      repeated_field->reset(new TypedRepeatedField<char>);
-      return true;
-    case reflection::UByte:
-      repeated_field->reset(new TypedRepeatedField<unsigned char>);
-      return true;
-    case reflection::Int:
-      repeated_field->reset(new TypedRepeatedField<int>);
-      return true;
-    case reflection::UInt:
-      repeated_field->reset(new TypedRepeatedField<uint>);
-      return true;
-    case reflection::Long:
-      repeated_field->reset(new TypedRepeatedField<int64>);
-      return true;
-    case reflection::ULong:
-      repeated_field->reset(new TypedRepeatedField<uint64>);
-      return true;
-    case reflection::Float:
-      repeated_field->reset(new TypedRepeatedField<float>);
-      return true;
-    case reflection::Double:
-      repeated_field->reset(new TypedRepeatedField<double>);
-      return true;
-    case reflection::String:
-      repeated_field->reset(new TypedRepeatedField<std::string>);
-      return true;
-    case reflection::Obj:
-      repeated_field->reset(
-          new TypedRepeatedField<ReflectiveFlatbuffer>(schema, type));
-      return true;
-    default:
-      TC3_LOG(ERROR) << "Unsupported type: " << type->element();
-      return false;
-  }
-}
 
 // Gets the field information for a field name, returns nullptr if the
 // field was not defined.
@@ -76,8 +33,8 @@
   return type->fields()->LookupByKey(field_name.data());
 }
 
-const reflection::Field* GetFieldByOffsetOrNull(const reflection::Object* type,
-                                                const int field_offset) {
+const reflection::Field* GetFieldOrNull(const reflection::Object* type,
+                                        const int field_offset) {
   if (type->fields() == nullptr) {
     return nullptr;
   }
@@ -97,14 +54,14 @@
   if (!field_name.empty()) {
     return GetFieldOrNull(type, field_name.data());
   }
-  return GetFieldByOffsetOrNull(type, field_offset);
+  return GetFieldOrNull(type, field_offset);
 }
 
 const reflection::Field* GetFieldOrNull(const reflection::Object* type,
                                         const FlatbufferField* field) {
   TC3_CHECK(type != nullptr && field != nullptr);
   if (field->field_name() == nullptr) {
-    return GetFieldByOffsetOrNull(type, field->field_offset());
+    return GetFieldOrNull(type, field->field_offset());
   }
   return GetFieldOrNull(
       type,
@@ -154,7 +111,7 @@
     return false;
   }
   if (field->type()->base_type() == reflection::Vector) {
-    buffer->Repeated<T>(field)->Add(value);
+    buffer->Repeated(field)->Add(value);
     return true;
   } else {
     return buffer->Set<T>(field, value);
@@ -221,9 +178,9 @@
   return true;
 }
 
-const reflection::Field* ReflectiveFlatbuffer::GetFieldByOffsetOrNull(
+const reflection::Field* ReflectiveFlatbuffer::GetFieldOrNull(
     const int field_offset) const {
-  return libtextclassifier3::GetFieldByOffsetOrNull(type_, field_offset);
+  return libtextclassifier3::GetFieldOrNull(type_, field_offset);
 }
 
 bool ReflectiveFlatbuffer::ParseAndSet(const reflection::Field* field,
@@ -257,6 +214,27 @@
   return parent->ParseAndSet(field, value);
 }
 
+// Appends a new sub-message to the repeated field `field_name` and returns
+// it, or nullptr if the field is unknown, not a vector, or not of tables.
+ReflectiveFlatbuffer* ReflectiveFlatbuffer::Add(StringPiece field_name) {
+  const reflection::Field* field = GetFieldOrNull(field_name);
+  if (field == nullptr ||
+      field->type()->base_type() != reflection::BaseType::Vector) {
+    return nullptr;
+  }
+  return Add(field);
+}
+
+// Same as above for an already resolved `field` (which may be null).
+ReflectiveFlatbuffer* ReflectiveFlatbuffer::Add(
+    const reflection::Field* field) {
+  // Repeated() returns nullptr for non-vector fields, so check before use
+  // instead of dereferencing it unconditionally.
+  RepeatedField* repeated_field =
+      field == nullptr ? nullptr : Repeated(field);
+  return repeated_field == nullptr ? nullptr : repeated_field->Add();
+}
+
 ReflectiveFlatbuffer* ReflectiveFlatbuffer::Mutable(
     const StringPiece field_name) {
   if (const reflection::Field* field = GetFieldOrNull(field_name)) {
@@ -306,11 +284,8 @@
   }
 
   // Otherwise, create a new instance and store it.
-  std::unique_ptr<RepeatedField> repeated_field;
-  if (!CreateRepeatedField(schema_, field->type(), &repeated_field)) {
-    TC3_LOG(ERROR) << "Could not create repeated field.";
-    return nullptr;
-  }
+  std::unique_ptr<RepeatedField> repeated_field(
+      new RepeatedField(schema_, field));
   const auto it = repeated_fields_.insert(
       /*hint=*/entry, std::make_pair(field, std::move(repeated_field)));
   return it->second.get();
@@ -330,9 +305,10 @@
 
   // Create strings.
   for (const auto& it : fields_) {
-    if (it.second.HasString()) {
-      offsets.push_back({it.first->offset(),
-                         builder->CreateString(it.second.StringValue()).o});
+    if (it.second.Has<std::string>()) {
+      offsets.push_back(
+          {it.first->offset(),
+           builder->CreateString(it.second.ConstRefValue<std::string>()).o});
     }
   }
 
@@ -349,44 +325,46 @@
     switch (it.second.GetType()) {
       case Variant::TYPE_BOOL_VALUE:
         builder->AddElement<uint8_t>(
-            it.first->offset(), static_cast<uint8_t>(it.second.BoolValue()),
+            it.first->offset(), static_cast<uint8_t>(it.second.Value<bool>()),
             static_cast<uint8_t>(it.first->default_integer()));
         continue;
       case Variant::TYPE_INT8_VALUE:
         builder->AddElement<int8_t>(
-            it.first->offset(), static_cast<int8_t>(it.second.Int8Value()),
+            it.first->offset(), static_cast<int8_t>(it.second.Value<int8>()),
             static_cast<int8_t>(it.first->default_integer()));
         continue;
       case Variant::TYPE_UINT8_VALUE:
         builder->AddElement<uint8_t>(
-            it.first->offset(), static_cast<uint8_t>(it.second.UInt8Value()),
+            it.first->offset(), static_cast<uint8_t>(it.second.Value<uint8>()),
             static_cast<uint8_t>(it.first->default_integer()));
         continue;
       case Variant::TYPE_INT_VALUE:
         builder->AddElement<int32>(
-            it.first->offset(), it.second.IntValue(),
+            it.first->offset(), it.second.Value<int>(),
             static_cast<int32>(it.first->default_integer()));
         continue;
       case Variant::TYPE_UINT_VALUE:
         builder->AddElement<uint32>(
-            it.first->offset(), it.second.UIntValue(),
+            it.first->offset(), it.second.Value<uint>(),
             static_cast<uint32>(it.first->default_integer()));
         continue;
       case Variant::TYPE_INT64_VALUE:
-        builder->AddElement<int64>(it.first->offset(), it.second.Int64Value(),
+        builder->AddElement<int64>(it.first->offset(), it.second.Value<int64>(),
                                    it.first->default_integer());
         continue;
       case Variant::TYPE_UINT64_VALUE:
-        builder->AddElement<uint64>(it.first->offset(), it.second.UInt64Value(),
+        builder->AddElement<uint64>(it.first->offset(),
+                                    it.second.Value<uint64>(),
                                     it.first->default_integer());
         continue;
       case Variant::TYPE_FLOAT_VALUE:
         builder->AddElement<float>(
-            it.first->offset(), it.second.FloatValue(),
+            it.first->offset(), it.second.Value<float>(),
             static_cast<float>(it.first->default_real()));
         continue;
       case Variant::TYPE_DOUBLE_VALUE:
-        builder->AddElement<double>(it.first->offset(), it.second.DoubleValue(),
+        builder->AddElement<double>(it.first->offset(),
+                                    it.second.Value<double>(),
                                     it.first->default_real());
         continue;
       default:
@@ -419,7 +397,7 @@
     return false;
   }
 
-  TypedRepeatedField<std::string>* to_repeated = Repeated<std::string>(field);
+  RepeatedField* to_repeated = Repeated(field);
   for (const flatbuffers::String* element : *from_vector) {
     to_repeated->Add(element->str());
   }
@@ -435,8 +413,7 @@
     return false;
   }
 
-  TypedRepeatedField<ReflectiveFlatbuffer>* to_repeated =
-      Repeated<ReflectiveFlatbuffer>(field);
+  RepeatedField* to_repeated = Repeated(field);
   for (const flatbuffers::Table* const from_element : *from_vector) {
     ReflectiveFlatbuffer* to_element = to_repeated->Add();
     if (to_element == nullptr) {
@@ -502,7 +479,9 @@
                        ->str());
         break;
       case reflection::Obj:
-        if (!Mutable(field)->MergeFrom(
+        if (ReflectiveFlatbuffer* nested_field = Mutable(field);
+            nested_field == nullptr ||
+            !nested_field->MergeFrom(
                 from->GetPointer<const flatbuffers::Table* const>(
                     field->offset()))) {
           return false;
@@ -635,4 +614,96 @@
   return true;
 }
 
+//
+// Repeated field methods.
+//
+
+ReflectiveFlatbuffer* RepeatedField::Add() {
+  if (is_primitive_) {  // Sub-messages only exist in vectors of tables.
+    TC3_LOG(ERROR) << "Trying to add sub-message on a primitive-typed field.";
+    return nullptr;
+  }
+  const reflection::Object* element_type =
+      schema_->objects()->Get(field_->type()->index());
+  object_items_.emplace_back(new ReflectiveFlatbuffer(schema_, element_type));
+  return object_items_.back().get();
+}
+
+namespace {
+// Converts every variant in `values` with Variant::Value<T>() and writes the
+// results to `builder` as a vector of T; returns the vector's offset.
+template <typename T>
+flatbuffers::uoffset_t TypedSerialize(const std::vector<Variant>& values,
+                                      flatbuffers::FlatBufferBuilder* builder) {
+  std::vector<T> converted(values.size());
+  for (size_t i = 0; i < values.size(); ++i) {
+    converted[i] = values[i].Value<T>();
+  }
+  return builder->CreateVector(converted).o;
+}
+
+}  // namespace
+
+// Writes the repeated field to `builder` as a flatbuffers vector and returns
+// its offset. The schema's element type picks the encoding: strings and
+// sub-messages become vectors of offsets, scalars are copied by value.
+flatbuffers::uoffset_t RepeatedField::Serialize(
+    flatbuffers::FlatBufferBuilder* builder) const {
+  switch (field_->type()->element()) {
+    // Non-scalar elements: each item is written first, then a vector of
+    // offsets referring to them.
+    case reflection::String:
+      return SerializeString(builder);
+    case reflection::Obj:
+      return SerializeObject(builder);
+
+    // Scalar elements: copy the stored variants into a typed vector. Every
+    // case returns directly, so no `break` is needed.
+    case reflection::Bool:
+      return TypedSerialize<bool>(items_, builder);
+    case reflection::Byte:
+      return TypedSerialize<int8_t>(items_, builder);
+    case reflection::UByte:
+      return TypedSerialize<uint8_t>(items_, builder);
+    case reflection::Int:
+      return TypedSerialize<int>(items_, builder);
+    case reflection::UInt:
+      return TypedSerialize<uint>(items_, builder);
+    case reflection::Long:
+      return TypedSerialize<int64>(items_, builder);
+    case reflection::ULong:
+      return TypedSerialize<uint64>(items_, builder);
+    case reflection::Float:
+      return TypedSerialize<float>(items_, builder);
+    case reflection::Double:
+      return TypedSerialize<double>(items_, builder);
+    default:
+      // Any element type not listed above cannot be serialized here.
+      TC3_LOG(FATAL) << "Unsupported type: " << field_->type()->element();
+      break;
+  }
+  // Not expected to be reached (see the FATAL log above); the return keeps
+  // every control path returning a value.
+  TC3_LOG(FATAL) << "Invalid state.";
+  return 0;
+}
+
+flatbuffers::uoffset_t RepeatedField::SerializeString(
+    flatbuffers::FlatBufferBuilder* builder) const {
+  std::vector<flatbuffers::Offset<flatbuffers::String>> offsets(items_.size());
+  for (size_t i = 0; i < offsets.size(); ++i) {  // Strings before the vector.
+    offsets[i] = builder->CreateString(items_[i].ConstRefValue<std::string>());
+  }
+  return builder->CreateVector(offsets).o;
+}
+
+flatbuffers::uoffset_t RepeatedField::SerializeObject(
+    flatbuffers::FlatBufferBuilder* builder) const {
+  std::vector<flatbuffers::Offset<void>> offsets(object_items_.size());
+  for (size_t i = 0; i < offsets.size(); ++i) {  // Children before the vector.
+    offsets[i] = object_items_[i]->Serialize(builder);
+  }
+  return builder->CreateVector(offsets).o;
+}
+
 }  // namespace libtextclassifier3
diff --git a/native/utils/flatbuffers.h b/native/utils/flatbuffers.h
index 93a4109..aaf248e 100644
--- a/native/utils/flatbuffers.h
+++ b/native/utils/flatbuffers.h
@@ -19,7 +19,6 @@
 #ifndef LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
 #define LIBTEXTCLASSIFIER_UTILS_FLATBUFFERS_H_
 
-#include <map>
 #include <memory>
 #include <string>
 #include <unordered_map>
@@ -31,13 +30,12 @@
 #include "utils/variant.h"
 #include "flatbuffers/flatbuffers.h"
 #include "flatbuffers/reflection.h"
+#include "flatbuffers/reflection_generated.h"
 
 namespace libtextclassifier3 {
 
 class ReflectiveFlatBuffer;
 class RepeatedField;
-template <typename T>
-class TypedRepeatedField;
 
 // Loads and interprets the buffer as 'FlatbufferMessage' and verifies its
 // integrity.
@@ -105,6 +103,41 @@
                      builder.GetSize());
 }
 
+class ReflectiveFlatbuffer;  // Forward declaration for IsMatchingType below.
+
+// Returns whether C++ type `T` may be stored in a field of base type `type`.
+template <typename T>
+bool IsMatchingType(const reflection::BaseType type) {
+  switch (type) {  // Scalar base types require the exact C++ type.
+    case reflection::Bool:
+      return std::is_same<T, bool>::value;
+    case reflection::Byte:
+      return std::is_same<T, int8>::value;
+    case reflection::UByte:
+      return std::is_same<T, uint8>::value;
+    case reflection::Int:
+      return std::is_same<T, int32>::value;
+    case reflection::UInt:
+      return std::is_same<T, uint32>::value;
+    case reflection::Long:
+      return std::is_same<T, int64>::value;
+    case reflection::ULong:
+      return std::is_same<T, uint64>::value;
+    case reflection::Float:
+      return std::is_same<T, float>::value;
+    case reflection::Double:
+      return std::is_same<T, double>::value;
+    case reflection::String:  // Owning or non-owning string representations.
+      return std::is_same<T, std::string>::value ||
+             std::is_same<T, StringPiece>::value ||
+             std::is_same<T, const char*>::value;
+    case reflection::Obj:  // Nested objects built via ReflectiveFlatbuffer.
+      return std::is_same<T, ReflectiveFlatbuffer>::value;
+    default:  // Any other base type (e.g. Vector) never matches.
+      return false;
+  }
+}
+
 // A flatbuffer that can be built using flatbuffer reflection data of the
 // schema.
 // Normally, field information is hard-coded in code generated from a flatbuffer
@@ -123,119 +156,58 @@
   // field was not defined.
   const reflection::Field* GetFieldOrNull(const StringPiece field_name) const;
   const reflection::Field* GetFieldOrNull(const FlatbufferField* field) const;
-  const reflection::Field* GetFieldByOffsetOrNull(const int field_offset) const;
+  const reflection::Field* GetFieldOrNull(const int field_offset) const;
 
   // Gets a nested field and the message it is defined on.
   bool GetFieldWithParent(const FlatbufferFieldPath* field_path,
                           ReflectiveFlatbuffer** parent,
                           reflection::Field const** field);
 
-  // Checks whether a variant value type agrees with a field type.
-  template <typename T>
-  bool IsMatchingType(const reflection::BaseType type) const {
-    switch (type) {
-      case reflection::Bool:
-        return std::is_same<T, bool>::value;
-      case reflection::Byte:
-        return std::is_same<T, int8>::value;
-      case reflection::UByte:
-        return std::is_same<T, uint8>::value;
-      case reflection::Int:
-        return std::is_same<T, int32>::value;
-      case reflection::UInt:
-        return std::is_same<T, uint32>::value;
-      case reflection::Long:
-        return std::is_same<T, int64>::value;
-      case reflection::ULong:
-        return std::is_same<T, uint64>::value;
-      case reflection::Float:
-        return std::is_same<T, float>::value;
-      case reflection::Double:
-        return std::is_same<T, double>::value;
-      case reflection::String:
-        return std::is_same<T, std::string>::value ||
-               std::is_same<T, StringPiece>::value ||
-               std::is_same<T, const char*>::value;
-      case reflection::Obj:
-        return std::is_same<T, ReflectiveFlatbuffer>::value;
-      default:
-        return false;
-    }
-  }
-
-  // Sets a (primitive) field to a specific value.
+  // Sets a field to a specific value.
   // Returns true if successful, and false if the field was not found or the
   // expected type doesn't match.
   template <typename T>
-  bool Set(StringPiece field_name, T value) {
-    if (const reflection::Field* field = GetFieldOrNull(field_name)) {
-      return Set<T>(field, value);
-    }
-    return false;
-  }
+  bool Set(StringPiece field_name, T value);
 
-  // Sets a (primitive) field to a specific value.
+  // Sets a field to a specific value.
   // Returns true if successful, and false if the expected type doesn't match.
   // Expects `field` to be non-null.
   template <typename T>
-  bool Set(const reflection::Field* field, T value) {
-    if (field == nullptr) {
-      TC3_LOG(ERROR) << "Expected non-null field.";
-      return false;
-    }
-    Variant variant_value(value);
-    if (!IsMatchingType<T>(field->type()->base_type())) {
-      TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
-                     << "`, expected: " << field->type()->base_type()
-                     << ", got: " << variant_value.GetType();
-      return false;
-    }
-    fields_[field] = variant_value;
-    return true;
-  }
+  bool Set(const reflection::Field* field, T value);
 
+  // Sets a field to a specific value. Field is specified by path.
   template <typename T>
-  bool Set(const FlatbufferFieldPath* path, T value) {
-    ReflectiveFlatbuffer* parent;
-    const reflection::Field* field;
-    if (!GetFieldWithParent(path, &parent, &field)) {
-      return false;
-    }
-    return parent->Set<T>(field, value);
-  }
+  bool Set(const FlatbufferFieldPath* path, T value);
 
-  // Sets a (primitive) field to a specific value.
-  // Parses the string value according to the field type.
-  bool ParseAndSet(const reflection::Field* field, const std::string& value);
-  bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
-
-  // Gets the reflective flatbuffer for a table field.
+  // Sets sub-message field (if not set yet), and returns a pointer to it.
   // Returns nullptr if the field was not found, or the field type was not a
   // table.
   ReflectiveFlatbuffer* Mutable(StringPiece field_name);
   ReflectiveFlatbuffer* Mutable(const reflection::Field* field);
 
+  // Parses the value (according to the type) and sets a primitive field to the
+  // parsed value.
+  bool ParseAndSet(const reflection::Field* field, const std::string& value);
+  bool ParseAndSet(const FlatbufferFieldPath* path, const std::string& value);
+
+  // Adds a primitive value to the repeated field.
+  template <typename T>
+  bool Add(StringPiece field_name, T value);
+
+  // Add a sub-message to the repeated field.
+  ReflectiveFlatbuffer* Add(StringPiece field_name);
+
+  template <typename T>
+  bool Add(const reflection::Field* field, T value);
+
+  ReflectiveFlatbuffer* Add(const reflection::Field* field);
+
   // Gets the reflective flatbuffer for a repeated field.
   // Returns nullptr if the field was not found, or the field type was not a
   // vector.
   RepeatedField* Repeated(StringPiece field_name);
   RepeatedField* Repeated(const reflection::Field* field);
 
-  template <typename T>
-  TypedRepeatedField<T>* Repeated(const reflection::Field* field) {
-    if (!IsMatchingType<T>(field->type()->element())) {
-      TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
-                     << "`";
-      return nullptr;
-    }
-    return static_cast<TypedRepeatedField<T>*>(Repeated(field));
-  }
-
-  template <typename T>
-  TypedRepeatedField<T>* Repeated(StringPiece field_name) {
-    return static_cast<TypedRepeatedField<T>*>(Repeated(field_name));
-  }
-
   // Serializes the flatbuffer.
   flatbuffers::uoffset_t Serialize(
       flatbuffers::FlatBufferBuilder* builder) const;
@@ -318,77 +290,132 @@
 // Serves as a common base class for repeated fields.
 class RepeatedField {
  public:
-  virtual ~RepeatedField() {}
+  RepeatedField(const reflection::Schema* const schema,
+                const reflection::Field* field)
+      : schema_(schema),
+        field_(field),
+        is_primitive_(field->type()->element() != reflection::BaseType::Obj) {}
 
-  virtual flatbuffers::uoffset_t Serialize(
-      flatbuffers::FlatBufferBuilder* builder) const = 0;
-};
+  template <typename T>
+  bool Add(const T value);
 
-// Represents a repeated field of particular type.
-template <typename T>
-class TypedRepeatedField : public RepeatedField {
- public:
-  void Add(const T value) { items_.push_back(value); }
+  ReflectiveFlatbuffer* Add();
 
-  flatbuffers::uoffset_t Serialize(
-      flatbuffers::FlatBufferBuilder* builder) const override {
-    return builder->CreateVector(items_).o;
+  template <typename T>
+  T Get(int index) const {
+    return items_.at(index).Value<T>();
   }
 
- private:
-  std::vector<T> items_;
-};
-
-// Specialization for strings.
-template <>
-class TypedRepeatedField<std::string> : public RepeatedField {
- public:
-  void Add(const std::string& value) { items_.push_back(value); }
-
-  flatbuffers::uoffset_t Serialize(
-      flatbuffers::FlatBufferBuilder* builder) const override {
-    std::vector<flatbuffers::Offset<flatbuffers::String>> offsets(
-        items_.size());
-    for (int i = 0; i < items_.size(); i++) {
-      offsets[i] = builder->CreateString(items_[i]);
+  template <>
+  ReflectiveFlatbuffer* Get(int index) const {
+    if (is_primitive_) {
+      TC3_LOG(ERROR) << "Trying to get non-primitive value out of primitive "
+                        "repeated field.";
+      return nullptr;
     }
-    return builder->CreateVector(offsets).o;
+    return object_items_.at(index).get();
   }
 
- private:
-  std::vector<std::string> items_;
-};
-
-// Specialization for repeated sub-messages.
-template <>
-class TypedRepeatedField<ReflectiveFlatbuffer> : public RepeatedField {
- public:
-  TypedRepeatedField<ReflectiveFlatbuffer>(
-      const reflection::Schema* const schema,
-      const reflection::Type* const type)
-      : schema_(schema), type_(type) {}
-
-  ReflectiveFlatbuffer* Add() {
-    items_.emplace_back(new ReflectiveFlatbuffer(
-        schema_, schema_->objects()->Get(type_->index())));
-    return items_.back().get();
+  int Size() const {
+    if (is_primitive_) {
+      return items_.size();
+    } else {
+      return object_items_.size();
+    }
   }
 
   flatbuffers::uoffset_t Serialize(
-      flatbuffers::FlatBufferBuilder* builder) const override {
-    std::vector<flatbuffers::Offset<void>> offsets(items_.size());
-    for (int i = 0; i < items_.size(); i++) {
-      offsets[i] = items_[i]->Serialize(builder);
-    }
-    return builder->CreateVector(offsets).o;
-  }
+      flatbuffers::FlatBufferBuilder* builder) const;
 
  private:
+  flatbuffers::uoffset_t SerializeString(
+      flatbuffers::FlatBufferBuilder* builder) const;
+  flatbuffers::uoffset_t SerializeObject(
+      flatbuffers::FlatBufferBuilder* builder) const;
+
   const reflection::Schema* const schema_;
-  const reflection::Type* const type_;
-  std::vector<std::unique_ptr<ReflectiveFlatbuffer>> items_;
+  const reflection::Field* field_;
+  bool is_primitive_;
+
+  std::vector<Variant> items_;
+  std::vector<std::unique_ptr<ReflectiveFlatbuffer>> object_items_;
 };
 
+template <typename T>
+bool ReflectiveFlatbuffer::Set(StringPiece field_name, T value) {
+  if (const reflection::Field* field = GetFieldOrNull(field_name)) {
+    if (field->type()->base_type() == reflection::BaseType::Vector ||
+        field->type()->base_type() == reflection::BaseType::Obj) {
+      TC3_LOG(ERROR)
+          << "Trying to set a primitive value on a non-scalar field.";
+      return false;
+    }
+    return Set<T>(field, value);
+  }
+  TC3_LOG(ERROR) << "Couldn't find a field: " << field_name;
+  return false;
+}
+
+template <typename T>
+bool ReflectiveFlatbuffer::Set(const reflection::Field* field, T value) {
+  if (field == nullptr) {
+    TC3_LOG(ERROR) << "Expected non-null field.";
+    return false;
+  }
+  Variant variant_value(value);
+  if (!IsMatchingType<T>(field->type()->base_type())) {
+    TC3_LOG(ERROR) << "Type mismatch for field `" << field->name()->str()
+                   << "`, expected: " << field->type()->base_type()
+                   << ", got: " << variant_value.GetType();
+    return false;
+  }
+  fields_[field] = variant_value;
+  return true;
+}
+
+template <typename T>
+bool ReflectiveFlatbuffer::Set(const FlatbufferFieldPath* path, T value) {
+  ReflectiveFlatbuffer* parent;
+  const reflection::Field* field;
+  if (!GetFieldWithParent(path, &parent, &field)) {
+    return false;
+  }
+  return parent->Set<T>(field, value);
+}
+
+template <typename T>
+bool ReflectiveFlatbuffer::Add(StringPiece field_name, T value) {
+  const reflection::Field* field = GetFieldOrNull(field_name);
+  if (field == nullptr) {
+    TC3_LOG(ERROR) << "Couldn't find a field: " << field_name;
+    return false;
+  }
+  if (field->type()->base_type() != reflection::BaseType::Vector) {
+    TC3_LOG(ERROR) << "Trying to add a value to a non-repeated field.";
+    return false;
+  }
+  return Add<T>(field, value);
+}
+
+template <typename T>
+bool ReflectiveFlatbuffer::Add(const reflection::Field* field, T value) {
+  // Repeated() yields nullptr for unknown or non-vector fields, and the
+  // element type is checked by RepeatedField::Add(); propagate both failures.
+  RepeatedField* repeated = field != nullptr ? Repeated(field) : nullptr;
+  if (repeated == nullptr) return false;
+  return repeated->Add(value);
+}
+
+template <typename T>
+bool RepeatedField::Add(const T value) {
+  if (!is_primitive_ || !IsMatchingType<T>(field_->type()->element())) {
+    TC3_LOG(ERROR) << "Trying to add value of unmatching type.";
+    return false;
+  }
+  items_.push_back(Variant{value});
+  return true;
+}
+
 // Resolves field lookups by name to the concrete field offsets.
 bool SwapFieldNamesForOffsetsInPath(const reflection::Schema* schema,
                                     FlatbufferFieldPathT* path);
@@ -402,7 +429,7 @@
     return false;
   }
 
-  TypedRepeatedField<T>* to_repeated = Repeated<T>(field);
+  RepeatedField* to_repeated = Repeated(field);
   for (const T element : *from_vector) {
     to_repeated->Add(element);
   }
diff --git a/native/utils/grammar/utils/rules.cc b/native/utils/grammar/utils/rules.cc
index 69a06a8..d6e4b76 100644
--- a/native/utils/grammar/utils/rules.cc
+++ b/native/utils/grammar/utils/rules.cc
@@ -177,6 +177,13 @@
   return it->second;
 }
 
+void Rules::BindAnnotation(const std::string& nonterminal_name,
+                           const std::string& annotation_name) {
+  auto [_, inserted] = annotation_nonterminals_.insert(
+      {annotation_name, AddNonterminal(nonterminal_name)});
+  TC3_CHECK(inserted);
+}
+
 bool Rules::IsNonterminalOfName(const RhsElement& element,
                                 const std::string& nonterminal) const {
   if (element.is_terminal) {
diff --git a/native/utils/grammar/utils/rules.h b/native/utils/grammar/utils/rules.h
index 5cc20d7..5a2cbc2 100644
--- a/native/utils/grammar/utils/rules.h
+++ b/native/utils/grammar/utils/rules.h
@@ -153,6 +153,10 @@
   // Defines a nonterminal for an externally provided annotation.
   int AddAnnotation(const std::string& annotation_name);
 
+  // Binds annotation `annotation_name` to nonterminal `nonterminal_name`.
+  void BindAnnotation(const std::string& nonterminal_name,
+                      const std::string& annotation_name);
+
   // Adds an alias for a nonterminal. This is a separate name for the same
   // nonterminal.
   void AddAlias(const std::string& nonterminal_name, const std::string& alias);
diff --git a/native/utils/intents/jni.cc b/native/utils/intents/jni.cc
index 1c6c283..051d078 100644
--- a/native/utils/intents/jni.cc
+++ b/native/utils/intents/jni.cc
@@ -175,40 +175,41 @@
     case Variant::TYPE_INT_VALUE:
       return JniHelper::NewObject(env, named_variant_class_.get(),
                                   named_variant_from_int_, name.get(),
-                                  value.IntValue());
+                                  value.Value<int>());
 
     case Variant::TYPE_INT64_VALUE:
       return JniHelper::NewObject(env, named_variant_class_.get(),
                                   named_variant_from_long_, name.get(),
-                                  value.Int64Value());
+                                  value.Value<int64>());
 
     case Variant::TYPE_FLOAT_VALUE:
       return JniHelper::NewObject(env, named_variant_class_.get(),
                                   named_variant_from_float_, name.get(),
-                                  value.FloatValue());
+                                  value.Value<float>());
 
     case Variant::TYPE_DOUBLE_VALUE:
       return JniHelper::NewObject(env, named_variant_class_.get(),
                                   named_variant_from_double_, name.get(),
-                                  value.DoubleValue());
+                                  value.Value<double>());
 
     case Variant::TYPE_BOOL_VALUE:
       return JniHelper::NewObject(env, named_variant_class_.get(),
                                   named_variant_from_bool_, name.get(),
-                                  value.BoolValue());
+                                  value.Value<bool>());
 
     case Variant::TYPE_STRING_VALUE: {
       TC3_ASSIGN_OR_RETURN(
           ScopedLocalRef<jstring> value_jstring,
-          jni_cache_->ConvertToJavaString(value.StringValue()));
+          jni_cache_->ConvertToJavaString(value.ConstRefValue<std::string>()));
       return JniHelper::NewObject(env, named_variant_class_.get(),
                                   named_variant_from_string_, name.get(),
                                   value_jstring.get());
     }
 
     case Variant::TYPE_STRING_VECTOR_VALUE: {
-      TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jobjectArray> value_jstring_array,
-                           AsStringArray(value.StringVectorValue()));
+      TC3_ASSIGN_OR_RETURN(
+          ScopedLocalRef<jobjectArray> value_jstring_array,
+          AsStringArray(value.ConstRefValue<std::vector<std::string>>()));
 
       return JniHelper::NewObject(env, named_variant_class_.get(),
                                   named_variant_from_string_array_, name.get(),
@@ -216,8 +217,9 @@
     }
 
     case Variant::TYPE_FLOAT_VECTOR_VALUE: {
-      TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jfloatArray> value_jfloat_array,
-                           AsFloatArray(value.FloatVectorValue()));
+      TC3_ASSIGN_OR_RETURN(
+          ScopedLocalRef<jfloatArray> value_jfloat_array,
+          AsFloatArray(value.ConstRefValue<std::vector<float>>()));
 
       return JniHelper::NewObject(env, named_variant_class_.get(),
                                   named_variant_from_float_array_, name.get(),
@@ -226,7 +228,7 @@
 
     case Variant::TYPE_INT_VECTOR_VALUE: {
       TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jintArray> value_jint_array,
-                           AsIntArray(value.IntVectorValue()));
+                           AsIntArray(value.ConstRefValue<std::vector<int>>()));
 
       return JniHelper::NewObject(env, named_variant_class_.get(),
                                   named_variant_from_int_array_, name.get(),
@@ -234,8 +236,10 @@
     }
 
     case Variant::TYPE_STRING_VARIANT_MAP_VALUE: {
-      TC3_ASSIGN_OR_RETURN(ScopedLocalRef<jobjectArray> value_jobect_array,
-                           AsNamedVariantArray(value.StringVariantMapValue()));
+      TC3_ASSIGN_OR_RETURN(
+          ScopedLocalRef<jobjectArray> value_jobect_array,
+          AsNamedVariantArray(
+              value.ConstRefValue<std::map<std::string, Variant>>()));
       return JniHelper::NewObject(env, named_variant_class_.get(),
                                   named_variant_from_named_variant_array_,
                                   name.get(), value_jobect_array.get());
diff --git a/native/utils/lua-utils.cc b/native/utils/lua-utils.cc
index fa19923..d6fe2c4 100644
--- a/native/utils/lua-utils.cc
+++ b/native/utils/lua-utils.cc
@@ -223,6 +223,11 @@
 
 int LuaEnvironment::ReadFlatbuffer(const int index,
                                    ReflectiveFlatbuffer* buffer) const {
+  if (buffer == nullptr) {
+    TC3_LOG(ERROR) << "Called ReadFlatbuffer with null buffer: " << index;
+    lua_error(state_);
+    return LUA_ERRRUN;
+  }
   if (lua_type(state_, /*idx=*/index) != LUA_TTABLE) {
     TC3_LOG(ERROR) << "Expected table, got: "
                    << lua_type(state_, /*idx=*/kIndexStackTop);
@@ -278,48 +283,48 @@
         // Read repeated field.
         switch (field->type()->element()) {
           case reflection::Bool:
-            ReadRepeatedField(/*index=*/kIndexStackTop,
-                              buffer->Repeated<bool>(field));
+            ReadRepeatedField<bool>(/*index=*/kIndexStackTop,
+                                    buffer->Repeated(field));
             break;
           case reflection::Byte:
-            ReadRepeatedField(/*index=*/kIndexStackTop,
-                              buffer->Repeated<int8>(field));
+            ReadRepeatedField<int8>(/*index=*/kIndexStackTop,
+                                    buffer->Repeated(field));
             break;
           case reflection::UByte:
-            ReadRepeatedField(/*index=*/kIndexStackTop,
-                              buffer->Repeated<uint8>(field));
+            ReadRepeatedField<uint8>(/*index=*/kIndexStackTop,
+                                     buffer->Repeated(field));
             break;
           case reflection::Int:
-            ReadRepeatedField(/*index=*/kIndexStackTop,
-                              buffer->Repeated<int32>(field));
+            ReadRepeatedField<int32>(/*index=*/kIndexStackTop,
+                                     buffer->Repeated(field));
             break;
           case reflection::UInt:
-            ReadRepeatedField(/*index=*/kIndexStackTop,
-                              buffer->Repeated<uint32>(field));
+            ReadRepeatedField<uint32>(/*index=*/kIndexStackTop,
+                                      buffer->Repeated(field));
             break;
           case reflection::Long:
-            ReadRepeatedField(/*index=*/kIndexStackTop,
-                              buffer->Repeated<int64>(field));
+            ReadRepeatedField<int64>(/*index=*/kIndexStackTop,
+                                     buffer->Repeated(field));
             break;
           case reflection::ULong:
-            ReadRepeatedField(/*index=*/kIndexStackTop,
-                              buffer->Repeated<uint64>(field));
+            ReadRepeatedField<uint64>(/*index=*/kIndexStackTop,
+                                      buffer->Repeated(field));
             break;
           case reflection::Float:
-            ReadRepeatedField(/*index=*/kIndexStackTop,
-                              buffer->Repeated<float>(field));
+            ReadRepeatedField<float>(/*index=*/kIndexStackTop,
+                                     buffer->Repeated(field));
             break;
           case reflection::Double:
-            ReadRepeatedField(/*index=*/kIndexStackTop,
-                              buffer->Repeated<double>(field));
+            ReadRepeatedField<double>(/*index=*/kIndexStackTop,
+                                      buffer->Repeated(field));
             break;
           case reflection::String:
-            ReadRepeatedField(/*index=*/kIndexStackTop,
-                              buffer->Repeated<std::string>(field));
+            ReadRepeatedField<std::string>(/*index=*/kIndexStackTop,
+                                           buffer->Repeated(field));
             break;
           case reflection::Obj:
-            ReadRepeatedField(/*index=*/kIndexStackTop,
-                              buffer->Repeated<ReflectiveFlatbuffer>(field));
+            ReadRepeatedField<ReflectiveFlatbuffer>(/*index=*/kIndexStackTop,
+                                                    buffer->Repeated(field));
             break;
           default:
             TC3_LOG(ERROR) << "Unsupported repeated field type: "
diff --git a/native/utils/lua-utils.h b/native/utils/lua-utils.h
index f602aa0..b01471a 100644
--- a/native/utils/lua-utils.h
+++ b/native/utils/lua-utils.h
@@ -506,15 +506,15 @@
 
   // Reads a repeated field from lua.
   template <typename T>
-  void ReadRepeatedField(const int index, TypedRepeatedField<T>* result) const {
+  void ReadRepeatedField(const int index, RepeatedField* result) const {
     for (const auto& element : ReadVector<T>(index)) {
-      result->Add(element);
+      result->Add<T>(element);  // Explicit T: `element` may be a proxy type.
     }
   }
 
   template <>
-  void ReadRepeatedField<ReflectiveFlatbuffer>(
-      const int index, TypedRepeatedField<ReflectiveFlatbuffer>* result) const {
+  void ReadRepeatedField<ReflectiveFlatbuffer>(const int index,
+                                               RepeatedField* result) const {
     lua_pushnil(state_);
     while (Next(index - 1)) {
       ReadFlatbuffer(index, result->Add());
diff --git a/native/utils/variant.cc b/native/utils/variant.cc
index 9cdc0b6..0513440 100644
--- a/native/utils/variant.cc
+++ b/native/utils/variant.cc
@@ -21,26 +21,26 @@
 std::string Variant::ToString() const {
   switch (GetType()) {
     case Variant::TYPE_BOOL_VALUE:
-      if (BoolValue()) {
+      if (Value<bool>()) {
         return "true";
       } else {
         return "false";
       }
       break;
     case Variant::TYPE_INT_VALUE:
-      return std::to_string(IntValue());
+      return std::to_string(Value<int>());
       break;
     case Variant::TYPE_INT64_VALUE:
-      return std::to_string(Int64Value());
+      return std::to_string(Value<int64>());
       break;
     case Variant::TYPE_FLOAT_VALUE:
-      return std::to_string(FloatValue());
+      return std::to_string(Value<float>());
       break;
     case Variant::TYPE_DOUBLE_VALUE:
-      return std::to_string(DoubleValue());
+      return std::to_string(Value<double>());
       break;
     case Variant::TYPE_STRING_VALUE:
-      return StringValue();
+      return ConstRefValue<std::string>();
       break;
     default:
       TC3_LOG(FATAL) << "Unsupported variant type: " << GetType();
diff --git a/native/utils/variant.h b/native/utils/variant.h
index 11c361c..551a822 100644
--- a/native/utils/variant.h
+++ b/native/utils/variant.h
@@ -85,110 +85,178 @@
 
   Variant& operator=(const Variant&) = default;
 
-  int Int8Value() const {
-    TC3_CHECK(HasInt8());
+  template <class T>
+  struct dependent_false : std::false_type {};
+
+  template <typename T>
+  T Value() const {
+    static_assert(dependent_false<T>::value, "Not supported.");
+  }
+
+  template <>
+  int8 Value() const {
+    TC3_CHECK(Has<int8>());
     return int8_value_;
   }
 
-  int UInt8Value() const {
-    TC3_CHECK(HasUInt8());
+  template <>
+  uint8 Value() const {
+    TC3_CHECK(Has<uint8>());
     return uint8_value_;
   }
 
-  int IntValue() const {
-    TC3_CHECK(HasInt());
+  template <>
+  int Value() const {
+    TC3_CHECK(Has<int>());
     return int_value_;
   }
 
-  uint UIntValue() const {
-    TC3_CHECK(HasUInt());
+  template <>
+  uint Value() const {
+    TC3_CHECK(Has<uint>());
     return uint_value_;
   }
 
-  int64 Int64Value() const {
-    TC3_CHECK(HasInt64());
+  template <>
+  int64 Value() const {
+    TC3_CHECK(Has<int64>());
     return long_value_;
   }
 
-  uint64 UInt64Value() const {
-    TC3_CHECK(HasUInt64());
+  template <>
+  uint64 Value() const {
+    TC3_CHECK(Has<uint64>());
     return ulong_value_;
   }
 
-  float FloatValue() const {
-    TC3_CHECK(HasFloat());
+  template <>
+  float Value() const {
+    TC3_CHECK(Has<float>());
     return float_value_;
   }
 
-  double DoubleValue() const {
-    TC3_CHECK(HasDouble());
+  template <>
+  double Value() const {
+    TC3_CHECK(Has<double>());
     return double_value_;
   }
 
-  bool BoolValue() const {
-    TC3_CHECK(HasBool());
+  template <>
+  bool Value() const {
+    TC3_CHECK(Has<bool>());
     return bool_value_;
   }
 
-  const std::string& StringValue() const {
-    TC3_CHECK(HasString());
+  template <typename T>
+  const T& ConstRefValue() const;
+
+  template <>
+  const std::string& ConstRefValue() const {
+    TC3_CHECK(Has<std::string>());
     return string_value_;
   }
 
-  const std::vector<std::string>& StringVectorValue() const {
-    TC3_CHECK(HasStringVector());
+  template <>
+  const std::vector<std::string>& ConstRefValue() const {
+    TC3_CHECK(Has<std::vector<std::string>>());
     return string_vector_value_;
   }
 
-  const std::vector<float>& FloatVectorValue() const {
-    TC3_CHECK(HasFloatVector());
+  template <>
+  const std::vector<float>& ConstRefValue() const {
+    TC3_CHECK(Has<std::vector<float>>());
     return float_vector_value_;
   }
 
-  const std::vector<int>& IntVectorValue() const {
-    TC3_CHECK(HasIntVector());
+  template <>
+  const std::vector<int>& ConstRefValue() const {
+    TC3_CHECK(Has<std::vector<int>>());
     return int_vector_value_;
   }
 
-  const std::map<std::string, Variant>& StringVariantMapValue() const {
-    TC3_CHECK(HasStringVariantMap());
+  template <>
+  const std::map<std::string, Variant>& ConstRefValue() const {
+    TC3_CHECK((Has<std::map<std::string, Variant>>()));
     return string_variant_map_value_;
   }
 
+  template <typename T>
+  bool Has() const;
+
+  template <>
+  bool Has<int8>() const {
+    return type_ == TYPE_INT8_VALUE;
+  }
+
+  template <>
+  bool Has<uint8>() const {
+    return type_ == TYPE_UINT8_VALUE;
+  }
+
+  template <>
+  bool Has<int>() const {
+    return type_ == TYPE_INT_VALUE;
+  }
+
+  template <>
+  bool Has<uint>() const {
+    return type_ == TYPE_UINT_VALUE;
+  }
+
+  template <>
+  bool Has<int64>() const {
+    return type_ == TYPE_INT64_VALUE;
+  }
+
+  template <>
+  bool Has<uint64>() const {
+    return type_ == TYPE_UINT64_VALUE;
+  }
+
+  template <>
+  bool Has<float>() const {
+    return type_ == TYPE_FLOAT_VALUE;
+  }
+
+  template <>
+  bool Has<double>() const {
+    return type_ == TYPE_DOUBLE_VALUE;
+  }
+
+  template <>
+  bool Has<bool>() const {
+    return type_ == TYPE_BOOL_VALUE;
+  }
+
+  template <>
+  bool Has<std::string>() const {
+    return type_ == TYPE_STRING_VALUE;
+  }
+
+  template <>
+  bool Has<std::vector<std::string>>() const {
+    return type_ == TYPE_STRING_VECTOR_VALUE;
+  }
+
+  template <>
+  bool Has<std::vector<float>>() const {
+    return type_ == TYPE_FLOAT_VECTOR_VALUE;
+  }
+
+  template <>
+  bool Has<std::vector<int>>() const {
+    return type_ == TYPE_INT_VECTOR_VALUE;
+  }
+
+  template <>
+  bool Has<std::map<std::string, Variant>>() const {
+    return type_ == TYPE_STRING_VARIANT_MAP_VALUE;
+  }
+
   // Converts the value of this variant to its string representation, regardless
   // of the type of the actual value.
   std::string ToString() const;
 
-  bool HasInt8() const { return type_ == TYPE_INT8_VALUE; }
-
-  bool HasUInt8() const { return type_ == TYPE_UINT8_VALUE; }
-
-  bool HasInt() const { return type_ == TYPE_INT_VALUE; }
-
-  bool HasUInt() const { return type_ == TYPE_UINT_VALUE; }
-
-  bool HasInt64() const { return type_ == TYPE_INT64_VALUE; }
-
-  bool HasUInt64() const { return type_ == TYPE_UINT64_VALUE; }
-
-  bool HasFloat() const { return type_ == TYPE_FLOAT_VALUE; }
-
-  bool HasDouble() const { return type_ == TYPE_DOUBLE_VALUE; }
-
-  bool HasBool() const { return type_ == TYPE_BOOL_VALUE; }
-
-  bool HasString() const { return type_ == TYPE_STRING_VALUE; }
-
-  bool HasStringVector() const { return type_ == TYPE_STRING_VECTOR_VALUE; }
-
-  bool HasFloatVector() const { return type_ == TYPE_FLOAT_VECTOR_VALUE; }
-
-  bool HasIntVector() const { return type_ == TYPE_INT_VECTOR_VALUE; }
-
-  bool HasStringVariantMap() const {
-    return type_ == TYPE_STRING_VARIANT_MAP_VALUE;
-  }
-
   Type GetType() const { return type_; }
 
   bool HasValue() const { return type_ != TYPE_EMPTY; }