Add an IndexLookup layer.
PiperOrigin-RevId: 292968344
Change-Id: I7dadf31138366d5b0ad8ed7eb0a885647d25bc81
diff --git a/tensorflow/python/keras/layers/preprocessing/BUILD b/tensorflow/python/keras/layers/preprocessing/BUILD
index d416af6..999f3b8 100644
--- a/tensorflow/python/keras/layers/preprocessing/BUILD
+++ b/tensorflow/python/keras/layers/preprocessing/BUILD
@@ -67,6 +67,31 @@
)
py_library(
+ name = "index_lookup",
+ srcs = [
+ "index_lookup.py",
+ "index_lookup_v1.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tensor_spec",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/keras:backend",
+ "//tensorflow/python/keras/engine:base_preprocessing_layer",
+ "//tensorflow/python/ops/ragged",
+ ],
+)
+
+py_library(
name = "normalization",
srcs = [
"normalization.py",
@@ -93,6 +118,7 @@
srcs_version = "PY2AND3",
deps = [
":categorical_encoding",
+ ":index_lookup",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
@@ -169,6 +195,22 @@
],
)
+tf_py_test(
+ name = "index_lookup_test",
+ size = "medium",
+ srcs = ["index_lookup_test.py"],
+ python_version = "PY3",
+ deps = [
+ ":index_lookup",
+ ":preprocessing_test_utils",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/keras",
+ "//tensorflow/python/keras/utils:generic_utils",
+ "//tensorflow/python/ops/ragged:ragged_string_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
cuda_py_test(
name = "image_preprocessing_test",
size = "medium",
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup.py b/tensorflow/python/keras/layers/preprocessing/index_lookup.py
new file mode 100644
index 0000000..1fb4b6c
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup.py
@@ -0,0 +1,402 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras text vectorization preprocessing layer."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import json
+import operator
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
+from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops.ragged import ragged_functional_ops
+from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.util import compat
+
+# The string tokens in the extracted vocabulary
+_VOCAB_NAME = "vocab"
+
+# The string tokens in the full vocabulary
+_ACCUMULATOR_VOCAB_NAME = "vocab"
+# The total counts of each token in the vocabulary
+_ACCUMULATOR_COUNTS_NAME = "counts"
+
+
+class IndexLookup(CombinerPreprocessingLayer):
+ """Maps strings (or integers) from a vocabulary to integer indices.
+
+ This layer translates a set of arbitrary strings or integers into an integer
+ output via a table-based lookup, with optional out-of-vocabulary handling.
+
+ If desired, the user can call this layer's `adapt()` method on a data set,
+ which will analyze the data set, determine the frequency of individual string
+ or integer values, and create a vocabulary from them. This vocabulary can have
+ unlimited size or be capped, depending on the configuration options for this
+ layer; if there are more unique values in the input than the maximum
+ vocabulary size, the most frequent terms will be used to create the
+ vocabulary.
+
+ Attributes:
+ max_tokens: The maximum size of the vocabulary for this layer. If None,
+ there is no cap on the size of the vocabulary. Note that the vocabulary
+ does include OOV buckets, so the effective number of unique values in the
+ vocabulary is `(max_tokens - num_oov_tokens)` when this value is set.
+ num_oov_tokens: The number of out-of-vocabulary tokens to use; defaults to
+ 1. If this value is more than 1, OOV inputs are hashed to determine their
+ OOV value; if this value is 0, passing an OOV input will result in a
+ runtime error.
+ reserve_zero: Whether to reserve the index 0, which indicates pad values in
+ the Keras masking system. If True, the output of this layer will be in the
+ range `[1...max_tokens+1)`; if False, the output will be in the range
+ `[0...max_tokens)`. Defaults to True.
+ mask_zero: If True, input values of 0 (for integers) and `""` (for strings)
+ will be treated as masked values and assigned an output value of 0. If
+ this option is set, `reserve_zero` must also be set. Defaults to False.
+ """
+ # TODO(momernick): Add an examples section to the docstring.
+
+ def __init__(self,
+ max_tokens,
+ num_oov_tokens=1,
+ reserve_zero=True,
+ mask_zero=False,
+ **kwargs):
+ allowed_dtypes = [dtypes.string, dtypes.int64]
+ if "dtype" in kwargs and kwargs["dtype"] not in allowed_dtypes:
+ raise ValueError(
+ "TextVectorization may only have a dtype of string or int64.")
+ elif "dtype" not in kwargs:
+ kwargs["dtype"] = dtypes.string
+
+ # If max_tokens is set, the value must be greater than 1 - otherwise we
+ # are creating a 0-element vocab, which doesn't make sense.
+ if max_tokens is not None and max_tokens <= 1:
+ raise ValueError("max_tokens must be greater than 1.")
+
+ # For now, limit the num_oov_tokens to one.
+ if num_oov_tokens != 1:
+ raise ValueError("num_oov_tokens must be 1 for the time being. Other "
+ "values will be supported in the near future. "
+ "You passed %s" % num_oov_tokens)
+
+ self.max_tokens = max_tokens
+ self.num_oov_tokens = num_oov_tokens
+ self.reserve_zero = reserve_zero
+ self.mask_zero = mask_zero
+
+ # We need to reserve at least num_oov_tokens tokens, plus one additional
+ # value if we are reserving the zero value in our output.
+ if reserve_zero:
+ self._reserved_values = (num_oov_tokens + 1)
+ else:
+ self._reserved_values = num_oov_tokens
+
+ # We need to account for the OOV buckets in our vocabulary size.
+ if max_tokens is not None:
+ self._max_elements = max_tokens - num_oov_tokens
+ else:
+ self._max_elements = None
+
+ # If there is only one OOV bucket, we can determine the OOV value (either 0
+ # or 1 depending on whether 0 is reserved) and set that as the default
+ value of the index_lookup table. If we have multiple OOV values, we need to
+ # do a further hashing step; to make this easier, we set the OOV value to
+ # -1. (This lets us do a vectorized add and cast to boolean to determine
+ # locations where we need to do extra hashing.)
+ if self.num_oov_tokens == 1:
+ self._oov_value = 1 if reserve_zero else 0
+ else:
+ self._oov_value = -1
+
+ super(IndexLookup, self).__init__(
+ combiner=_IndexLookupCombiner(self.max_tokens), **kwargs)
+
+ # This layer supports RaggedTensor inputs.
+ self._supports_ragged_inputs = True
+
+ # If the layer's input type is int32, we can only output int32 values -
+ # MutableHashTable doesn't allow us to map int32->int64.
+ if self.dtype == dtypes.int32:
+ self._output_dtype = dtypes.int32
+ else:
+ self._output_dtype = dtypes.int64
+
+ self._table = lookup_ops.MutableHashTable(
+ key_dtype=self.dtype,
+ value_dtype=self._output_dtype,
+ default_value=self._oov_value,
+ name=(self._name + "_index_table"))
+
+ # This is a workaround for saving not working yet for MutableHashTables.
+ # By replacing the existing function call by an explicit failure, we
+ # can provide a more user-friendly error message.
+ def fail(_):
+ raise NotImplementedError(
+ "Saving is not yet supported for IndexLookup layers.")
+
+ self._table._list_extra_dependencies_for_serialization = fail # pylint: disable=protected-access
+
+ tracked_table = self._add_trackable(self._table, trainable=False)
+ # This is a workaround for summary() on this layer. Because the table is
+ # not mutable during training, the effective number of parameters (and so
+ # the weight shape) is 0; we add this as an attr so that the parameter
+ # counting code in the Model object doesn't throw an attribute error.
+ tracked_table.shape = tensor_shape.TensorShape((0,))
+
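+ # These are V1/V2 shim points. There are V1 implementations in the V1 class.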
+ def _get_table_data(self):
+ keys, values = self._table.export()
+ return (keys.numpy(), values.numpy())
+
+ def vocab_size(self):
+ return self._table.size().numpy()
+
+ def _clear_table(self):
+ keys, _ = self._table.export()
+ self._table.remove(keys)
+
+ def _insert_table_data(self, keys, values):
+ if len(values) != len(keys):
+ raise RuntimeError("Size mismatch between values and key arrays. "
+ "Keys had size %s, values had size %s." %
+ (len(keys), len(values)))
+ self._table.insert(keys, values)
+
+ def _to_numpy(self, preprocessed_data):
+ """Converts preprocessed inputs into numpy arrays."""
+ if isinstance(preprocessed_data, np.ndarray):
+ return preprocessed_data
+ return np.array(preprocessed_data.to_list())
+ # End of V1/V2 shim points.
+
+ def _assert_same_type(self, expected_type, values, value_name):
+ if dtypes.as_dtype(expected_type) != dtypes.as_dtype(values.dtype):
+ raise RuntimeError("Expected %s type %s, got %s" %
+ (value_name, expected_type, values.dtype))
+
+ def _convert_to_ndarray(self, x):
+ return np.array(x) if isinstance(x, (list, tuple)) else x
+
+ def compute_output_shape(self, input_shape):
+ return input_shape
+
+ def compute_output_signature(self, input_spec):
+ output_shape = self.compute_output_shape(input_spec.shape.as_list())
+ output_dtype = dtypes.int64
+ return tensor_spec.TensorSpec(shape=output_shape, dtype=output_dtype)
+
+ def adapt(self, data, reset_state=True):
+ """Fits the state of the preprocessing layer to the dataset.
+
+ Overrides the default adapt method to apply relevant preprocessing to the
+ inputs before passing to the combiner.
+
+ Arguments:
+ data: The data to train on. It can be passed either as a tf.data Dataset,
+ or as a numpy array.
+ reset_state: Optional argument specifying whether to clear the state of
+ the layer at the start of the call to `adapt`. This must be True for
+ this layer, which does not support repeated calls to `adapt`.
+ """
+ if not reset_state:
+ raise ValueError("IndexLookup does not support streaming adapts.")
+ super(IndexLookup, self).adapt(data, reset_state)
+
+ def get_vocabulary(self):
+ if self.vocab_size() == 0:
+ return []
+
+ keys, values = self._get_table_data()
+ # This is required because the MutableHashTable doesn't preserve insertion
+ # order, but we rely on the order of the array to assign indices.
+ return [x for _, x in sorted(zip(values, keys))]
+
+ def get_config(self):
+ config = {
+ "max_tokens": self.max_tokens,
+ "num_oov_tokens": self.num_oov_tokens,
+ "reserve_zero": self.reserve_zero,
+ "mask_zero": self.mask_zero
+ }
+ base_config = super(IndexLookup, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+ def count_params(self):
+ # This method counts the number of scalars in the weights of this layer.
+ # Since this layer doesn't have any /actual/ weights (in that there's
+ # nothing in this layer that can be trained - we only use the weight
+ # abstraction for ease of saving!) we return 0.
+ return 0
+
+ def set_vocabulary(self,
+ vocab,
+ append=False):
+ """Sets vocabulary (and optionally document frequency) data for this layer.
+
+ This method sets the vocabulary for this layer directly, instead of
+ analyzing a dataset through 'adapt'. It should be used whenever the vocab
+ information is already known. If vocabulary data is already present in the
+ layer, this method will replace it if 'append' is set to False, or
+ append to it if 'append' is set to True.
+
+ Arguments:
+ vocab: An array of string tokens.
+ append: Whether to append to (True) or overwrite (False) any existing vocabulary data.
+
+ Raises:
+ ValueError: If the passed vocabulary, plus any existing vocabulary when
+ appending, is larger than the maximum vocabulary size.
+ """
+ current_table_size = self.vocab_size()
+ total_vocab_size = len(vocab) + (current_table_size if append else 0)
+ if self.max_tokens is not None and total_vocab_size > self._max_elements:
+ raise ValueError(
+ "Attempted to set a vocabulary larger than the maximum vocab size. "
+ "Passed vocab size is %s, max vocab size is %s. Note that the OOV "
+ "token(s) are automatically added to the number of tokens." %
+ (total_vocab_size, self.max_tokens))
+
+ start_index = self._reserved_values + (self.vocab_size() if append else 0)
+ values = np.arange(start_index, len(vocab) + start_index, dtype=np.int64)
+
+ vocab = self._convert_to_ndarray(vocab)
+ self._assert_same_type(self.dtype, vocab, "vocab")
+
+ values = self._convert_to_ndarray(values)
+ self._assert_same_type(self._output_dtype, values, "values")
+
+ if not append and self.vocab_size() > 0:
+ self._clear_table()
+ self._insert_table_data(vocab, values)
+
+ def _set_state_variables(self, updates):
+ if not self.built:
+ raise RuntimeError("_set_state_variables() must be called after build().")
+ self.set_vocabulary(updates[_VOCAB_NAME])
+
+ def call(self, inputs):
+ # The table lookup ops don't natively support ragged tensors, so if we have
+ # a RT we need to use map_flat_values to look up every element.
+ if ragged_tensor.is_ragged(inputs):
+ indexed_data = ragged_functional_ops.map_flat_values(
+ self._table.lookup, inputs)
+ else:
+ indexed_data = self._table.lookup(inputs)
+
+ return indexed_data
+
+
+class _IndexLookupCombiner(Combiner):
+ """Combiner for the IndexLookup preprocessing layer.
+
+ This class encapsulates the logic for computing a vocabulary based on the
+ frequency of each token.
+
+ Attributes:
+ vocab_size: (Optional) If set, only the top `vocab_size` tokens (based on
+ frequency across the dataset) are retained in the vocabulary. If None, or
+ set to a value greater than the total number of distinct tokens in the
+ dataset, all tokens are retained.
+ """
+ ACCUMULATOR_CLS = collections.namedtuple("Accumulator", ["count_dict"])
+
+ def __init__(self, vocab_size=None):
+ self._vocab_size = vocab_size
+
+ def compute(self, values, accumulator=None):
+ """Compute a step in this computation, returning a new accumulator."""
+ if ragged_tensor.is_ragged(values):
+ values = values.to_list()
+ if isinstance(values, ops.EagerTensor):
+ values = values.numpy()
+ if isinstance(values, np.ndarray):
+ values = values.tolist()
+
+ if accumulator is None:
+ accumulator = self._create_accumulator()
+
+ # TODO(momernick): Benchmark improvements to this algorithm.
+ for document in values:
+ for token in document:
+ accumulator.count_dict[token] += 1
+
+ return accumulator
+
+ def merge(self, accumulators):
+ """Merge several accumulators to a single accumulator."""
+ if not accumulators:
+ return accumulators
+
+ base_accumulator = accumulators[0]
+ for accumulator in accumulators[1:]:
+ for token, value in accumulator.count_dict.items():
+ base_accumulator.count_dict[token] += value
+
+ return base_accumulator
+
+ def extract(self, accumulator):
+ """Convert an accumulator into a dict of output values.
+
+ Args:
+ accumulator: An accumulator aggregating over the full dataset.
+
+ Returns:
+ A dict of:
+ "vocab": A list of the retained items in the vocabulary.
+ """
+ vocab_counts = accumulator.count_dict
+ sorted_counts = sorted(
+ vocab_counts.items(), key=operator.itemgetter(1, 0), reverse=True)
+ vocab_data = (
+ sorted_counts[:self._vocab_size] if self._vocab_size else sorted_counts)
+ vocab = [data[0] for data in vocab_data]
+ return {_VOCAB_NAME: vocab}
+
+ def restore(self, output):
+ """Create an accumulator based on 'output'."""
+ raise NotImplementedError(
+ "IndexLookup does not restore or support streaming updates.")
+
+ def serialize(self, accumulator):
+ """Serialize an accumulator for a remote call."""
+ output_dict = {}
+ output_dict["vocab"] = list(accumulator.count_dict.keys())
+ output_dict["vocab_counts"] = list(accumulator.count_dict.values())
+ return compat.as_bytes(json.dumps(output_dict))
+
+ def deserialize(self, encoded_accumulator):
+ """Deserialize an accumulator received from 'serialize()'."""
+ accumulator_dict = json.loads(compat.as_text(encoded_accumulator))
+
+ accumulator = self._create_accumulator()
+ count_dict = dict(
+ zip(accumulator_dict["vocab"], accumulator_dict["vocab_counts"]))
+ accumulator.count_dict.update(count_dict)
+
+ return accumulator
+
+ def _create_accumulator(self):
+ """Accumulate a sorted array of vocab tokens and corresponding counts."""
+
+ count_dict = collections.defaultdict(int)
+ return self.ACCUMULATOR_CLS(count_dict)
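The TODO above calls for an examples section; here is a minimal usage sketch of the layer defined in this file, assuming eager execution (the V2 class) and the defaults `reserve_zero=True` and `num_oov_tokens=1`, so index 0 is reserved for padding, index 1 is the OOV bucket, and vocabulary terms start at index 2:

```python
import numpy as np

from tensorflow.python.keras.layers.preprocessing import index_lookup

# Set the vocabulary directly instead of calling adapt() on a dataset.
layer = index_lookup.IndexLookup(max_tokens=None)
layer.set_vocabulary(["earth", "wind", "and", "fire"])

data = np.array([["earth", "wind", "and", "fire"],
                 ["fire", "and", "earth", "michigan"]])
# "michigan" is out of vocabulary and maps to the OOV index 1.
print(layer(data))  # [[2 3 4 5], [5 4 2 1]]
```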
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py
new file mode 100644
index 0000000..67bbe80
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py
@@ -0,0 +1,481 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Keras text vectorization preprocessing layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python import keras
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras.layers.preprocessing import index_lookup
+from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
+from tensorflow.python.keras.layers.preprocessing import preprocessing_test_utils
+from tensorflow.python.keras.saving import saved_model_experimental as saving
+from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
+from tensorflow.python.platform import test
+
+
+def get_layer_class():
+ if context.executing_eagerly():
+ return index_lookup.IndexLookup
+ else:
+ return index_lookup_v1.IndexLookup
+
+
+def _get_end_to_end_test_cases():
+ test_cases = (
+ {
+ "testcase_name":
+ "test_strings_soft_vocab_cap",
+ # Create an array where 'earth' is the most frequent term, followed by
+ # 'wind', then 'and', then 'fire'. This ensures that the vocab
+ # accumulator is sorting by frequency.
+ "vocab_data":
+ np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
+ ["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
+ "input_data":
+ np.array([["earth"], ["wind"], ["and"], ["fire"], ["fire"],
+ ["and"], ["earth"], ["michigan"]]),
+ "kwargs": {
+ "max_tokens": None,
+ },
+ "expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]],
+ "input_dtype":
+ dtypes.string
+ },
+ {
+ "testcase_name":
+ "test_ints_soft_vocab_cap",
+ # Create an array where 1138 is the most frequent term, followed by
+ # 1729, then 725, then 42. This ensures that the vocab accumulator
+ # is sorting by frequency.
+ "vocab_data":
+ np.array([[42], [1138], [1138], [1138], [1138], [1729], [1729],
+ [1729], [725], [725]]),
+ "input_data":
+ np.array([[1138], [1729], [725], [42], [42], [725], [1138], [4]]),
+ "kwargs": {
+ "max_tokens": None,
+ "dtype": dtypes.int64,
+ },
+ "expected_output": [[2], [3], [4], [5], [5], [4], [2], [1]],
+ "input_dtype":
+ dtypes.int64
+ },
+ )
+
+ crossed_test_cases = []
+ # Cross above test cases with use_dataset in (True, False)
+ for use_dataset in (True, False):
+ for case in test_cases:
+ case = case.copy()
+ if use_dataset:
+ case["testcase_name"] = case["testcase_name"] + "_with_dataset"
+ case["use_dataset"] = use_dataset
+ crossed_test_cases.append(case)
+
+ return crossed_test_cases
+
+
+@keras_parameterized.run_all_keras_modes
+class IndexLookupLayerTest(keras_parameterized.TestCase,
+ preprocessing_test_utils.PreprocessingLayerTest):
+
+ @parameterized.named_parameters(*_get_end_to_end_test_cases())
+ def test_layer_end_to_end_with_adapt(self, vocab_data, input_data, kwargs,
+ use_dataset, expected_output,
+ input_dtype):
+ cls = get_layer_class()
+ expected_output_dtype = dtypes.int64
+ input_shape = input_data.shape
+
+ if use_dataset:
+ # Keras APIs expect batched datasets.
+ # TODO(rachelim): `model.predict` predicts the result on each
+ # dataset batch separately, then tries to concatenate the results
+ # together. When the results have different shapes on the non-concat
+ # axis (which can happen in the output_mode = INT case for
+ # IndexLookup), the concatenation fails. In real use cases, this may
+ # not be an issue because users are likely to pipe the preprocessing layer
+ # into other keras layers instead of predicting it directly. A workaround
+ # for these unit tests is to have the dataset only contain one batch, so
+ # no concatenation needs to happen with the result. For consistency with
+ # numpy input, we should make `predict` join differently shaped results
+ # together sensibly, with 0 padding.
+ input_data = dataset_ops.Dataset.from_tensor_slices(input_data).batch(
+ input_shape[0])
+ vocab_data = dataset_ops.Dataset.from_tensor_slices(vocab_data).batch(
+ input_shape[0])
+
+ with CustomObjectScope({"IndexLookup": cls}):
+ output_data = testing_utils.layer_test(
+ cls,
+ kwargs=kwargs,
+ input_shape=input_shape,
+ input_data=input_data,
+ input_dtype=input_dtype,
+ expected_output_dtype=expected_output_dtype,
+ validate_training=False,
+ adapt_data=vocab_data)
+ self.assertAllClose(expected_output, output_data)
+
+
+@keras_parameterized.run_all_keras_modes
+class IndexLookupOutputTest(keras_parameterized.TestCase,
+ preprocessing_test_utils.PreprocessingLayerTest):
+
+ def test_int_output(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+ input_array = np.array([["earth", "wind", "and", "fire"],
+ ["fire", "and", "earth", "michigan"]])
+ expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]
+
+ input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+ layer = get_layer_class()(max_tokens=None)
+ layer.set_vocabulary(vocab_data)
+ int_data = layer(input_data)
+ model = keras.Model(inputs=input_data, outputs=int_data)
+ output_dataset = model.predict(input_array)
+ self.assertAllEqual(expected_output, output_dataset)
+
+ def test_int_output_no_reserved_zero(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+ input_array = np.array([["earth", "wind", "and", "fire"],
+ ["fire", "and", "earth", "michigan"]])
+ expected_output = [[1, 2, 3, 4], [4, 3, 1, 0]]
+
+ input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+ layer = get_layer_class()(max_tokens=None, reserve_zero=False)
+ layer.set_vocabulary(vocab_data)
+ int_data = layer(input_data)
+ model = keras.Model(inputs=input_data, outputs=int_data)
+ output_dataset = model.predict(input_array)
+ self.assertAllEqual(expected_output, output_dataset)
+
+ def test_vocab_appending(self):
+ vocab_data = [["earth", "wind"], ["and", "fire"]]
+ input_array = np.array([["earth", "wind", "and", "fire"],
+ ["fire", "and", "earth", "michigan"]])
+ expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]
+
+ input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+ layer = get_layer_class()(max_tokens=5)
+ layer.set_vocabulary(vocab_data[0])
+ layer.set_vocabulary(vocab_data[1], append=True)
+ int_data = layer(input_data)
+ model = keras.Model(inputs=input_data, outputs=int_data)
+ output_dataset = model.predict(input_array)
+ self.assertAllClose(expected_output, output_dataset)
+
+
+@keras_parameterized.run_all_keras_modes(always_skip_eager=True)
+class IndexLookupSaveableTest(keras_parameterized.TestCase,
+ preprocessing_test_utils.PreprocessingLayerTest):
+
+ def test_ops_are_not_added_with_multiple_get_set_weights(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+
+ input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+ layer = get_layer_class()(max_tokens=10)
+ layer.set_vocabulary(vocab_data)
+ int_data = layer(input_data)
+ model = keras.Model(inputs=input_data, outputs=int_data)
+ weights = model.get_weights()
+ model.set_weights(weights)
+ keras.backend.get_session().graph.finalize()
+ weights = model.get_weights()
+ model.set_weights(weights)
+
+
+@keras_parameterized.run_all_keras_modes
+class IndexLookupErrorTest(keras_parameterized.TestCase,
+ preprocessing_test_utils.PreprocessingLayerTest):
+
+ def test_too_long_vocab_fails_in_single_setting(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+
+ layer = get_layer_class()(max_tokens=4)
+ with self.assertRaisesRegex(ValueError,
+ "vocabulary larger than the maximum vocab.*"):
+ layer.set_vocabulary(vocab_data)
+
+ def test_too_long_vocab_fails_in_multiple_settings(self):
+ vocab_data = [["earth", "wind"], ["and", "fire"]]
+ layer = get_layer_class()(max_tokens=4)
+
+ # The first time we call set_vocabulary, we're under the max_tokens
+ # so it should be fine.
+ layer.set_vocabulary(vocab_data[0])
+ with self.assertRaisesRegex(ValueError,
+ "vocabulary larger than the maximum vocab.*"):
+ layer.set_vocabulary(vocab_data[1], append=True)
+
+ def test_zero_max_tokens_fails(self):
+ with self.assertRaisesRegex(ValueError, ".*max_tokens.*"):
+ _ = get_layer_class()(max_tokens=0)
+
+
+@keras_parameterized.run_all_keras_modes
+class IndexLookupSavingTest(keras_parameterized.TestCase,
+ preprocessing_test_utils.PreprocessingLayerTest):
+
+ def test_saving_errors(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+
+ # Build and validate a golden model.
+ input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+ layer = get_layer_class()(max_tokens=None)
+ layer.set_vocabulary(vocab_data)
+ int_data = layer(input_data)
+ model = keras.Model(inputs=input_data, outputs=int_data)
+
+ # Save the model to disk.
+ output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model")
+
+ with self.assertRaisesRegex(NotImplementedError, ".*Saving is not yet.*"):
+ model.save(output_path, save_format="tf")
+
+ def test_saving_errors_when_nested(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+
+ # Build and validate a golden model.
+ input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+ layer = get_layer_class()(max_tokens=None)
+ layer.set_vocabulary(vocab_data)
+ int_data = layer(input_data)
+ model = keras.Model(inputs=input_data, outputs=int_data)
+
+ outer_input = keras.Input(shape=(None,), dtype=dtypes.string)
+ outer_output = model(outer_input)
+ outer_model = keras.Model(inputs=outer_input, outputs=outer_output)
+
+ # Save the model to disk.
+ output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model")
+
+ with self.assertRaisesRegex(NotImplementedError, ".*Saving is not yet.*"):
+ outer_model.save(output_path, save_format="tf")
+
+ def DISABLED_test_vocabulary_persistence_across_saving(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+ input_array = np.array([["earth", "wind", "and", "fire"],
+ ["fire", "and", "earth", "michigan"]])
+ expected_output = [[2, 3, 4, 5], [5, 4, 2, 1]]
+
+ # Build and validate a golden model.
+ input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+ layer = get_layer_class()(max_tokens=None)
+ layer.set_vocabulary(vocab_data)
+ int_data = layer(input_data)
+ model = keras.Model(inputs=input_data, outputs=int_data)
+ output_dataset = model.predict(input_array)
+ self.assertAllEqual(output_dataset, expected_output)
+
+ # Save the model to disk.
+ output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model")
+ model.save(output_path, save_format="tf")
+ loaded_model = saving.load_from_saved_model(
+ output_path, custom_objects={"IndexLookup": get_layer_class()})
+
+ # Ensure that the loaded model is unique (so that the save/load is real)
+ self.assertIsNot(model, loaded_model)
+
+ # Validate correctness of the new model.
+ new_output_dataset = loaded_model.predict(input_array)
+ self.assertAllEqual(new_output_dataset, expected_output)
+
+ def DISABLED_test_vocabulary_persistence_across_saving_with_tfidf(self):
+ vocab_data = ["earth", "wind", "and", "fire"]
+ tfidf_data = [.5, .25, .2, .125]
+ input_array = np.array([["earth", "wind", "and", "earth"],
+ ["ohio", "fire", "earth", "michigan"]])
+
+ # pyformat: disable
+ # pylint: disable=bad-whitespace
+ expected_output = [[ 0, 1, .25, .2, 0],
+ [.1, .5, 0, 0, .125]]
+ # pylint: enable=bad-whitespace
+ # pyformat: enable
+
+ # Build and validate a golden model.
+ input_data = keras.Input(shape=(None,), dtype=dtypes.string)
+ layer = get_layer_class()(
+ max_tokens=5,
+ standardize=None,
+ split=None,
+ output_mode=index_lookup.TFIDF)
+ layer.set_vocabulary(vocab_data, df_data=tfidf_data, oov_df_value=.05)
+
+ int_data = layer(input_data)
+ model = keras.Model(inputs=input_data, outputs=int_data)
+ output_dataset = model.predict(input_array)
+ self.assertAllClose(output_dataset, expected_output)
+
+ # Save the model to disk.
+ output_path = os.path.join(self.get_temp_dir(), "tf_keras_saved_model")
+ model.save(output_path, save_format="tf")
+ loaded_model = saving.load_from_saved_model(
+ output_path, custom_objects={"IndexLookup": get_layer_class()})
+
+ # Ensure that the loaded model is unique (so that the save/load is real)
+ self.assertIsNot(model, loaded_model)
+
+ # Validate correctness of the new model.
+ new_output_dataset = loaded_model.predict(input_array)
+ self.assertAllClose(new_output_dataset, expected_output)
+
+
+@keras_parameterized.run_all_keras_modes
+class IndexLookupCombinerTest(keras_parameterized.TestCase,
+ preprocessing_test_utils.PreprocessingLayerTest):
+
+ def compare_text_accumulators(self, a, b, msg=None):
+ if a is None or b is None:
+ self.assertAllEqual(a, b, msg=msg)
+ return
+
+ self.assertAllEqual(a.count_dict, b.count_dict, msg=msg)
+
+ compare_accumulators = compare_text_accumulators
+
+ def update_accumulator(self, accumulator, data):
+ accumulator.count_dict.update(dict(zip(data["vocab"], data["counts"])))
+
+ return accumulator
+
+ def test_combiner_api_compatibility_int_mode(self):
+ data = np.array([["earth", "wind", "and", "fire"],
+ ["earth", "wind", "and", "michigan"]])
+ combiner = index_lookup._IndexLookupCombiner()
+ expected_accumulator_output = {
+ "vocab": np.array(["and", "earth", "wind", "fire", "michigan"]),
+ "counts": np.array([2, 2, 2, 1, 1]),
+ }
+ expected_extract_output = {
+ "vocab": np.array(["wind", "earth", "and", "michigan", "fire"]),
+ }
+ expected_accumulator = combiner._create_accumulator()
+ expected_accumulator = self.update_accumulator(expected_accumulator,
+ expected_accumulator_output)
+ self.validate_accumulator_serialize_and_deserialize(combiner, data,
+ expected_accumulator)
+ self.validate_accumulator_uniqueness(combiner, data)
+ self.validate_accumulator_extract(combiner, data, expected_extract_output)
+
+ # TODO(askerryryan): Add tests confirming equivalence to behavior of
+ # existing tf.keras.preprocessing.text.Tokenizer.
+ @parameterized.named_parameters(
+ {
+ "testcase_name":
+ "top_k_smaller_than_full_vocab",
+ "data":
+ np.array([["earth", "wind"], ["fire", "wind"], ["and"],
+ ["fire", "wind"]]),
+ "vocab_size":
+ 3,
+ "expected_accumulator_output": {
+ "vocab": np.array(["wind", "fire", "earth", "and"]),
+ "counts": np.array([3, 2, 1, 1]),
+ },
+ "expected_extract_output": {
+ "vocab": np.array(["wind", "fire", "earth"]),
+ },
+ },
+ {
+ "testcase_name":
+ "top_k_larger_than_full_vocab",
+ "data":
+ np.array([["earth", "wind"], ["fire", "wind"], ["and"],
+ ["fire", "wind"]]),
+ "vocab_size":
+ 10,
+ "expected_accumulator_output": {
+ "vocab": np.array(["wind", "fire", "earth", "and"]),
+ "counts": np.array([3, 2, 1, 1]),
+ },
+ "expected_extract_output": {
+ "vocab": np.array(["wind", "fire", "earth", "and"]),
+ },
+ },
+ {
+ "testcase_name":
+ "no_top_k",
+ "data":
+ np.array([["earth", "wind"], ["fire", "wind"], ["and"],
+ ["fire", "wind"]]),
+ "vocab_size":
+ None,
+ "expected_accumulator_output": {
+ "vocab": np.array(["wind", "fire", "earth", "and"]),
+ "counts": np.array([3, 2, 1, 1]),
+ },
+ "expected_extract_output": {
+ "vocab": np.array(["wind", "fire", "earth", "and"]),
+ },
+ },
+ {
+ "testcase_name": "single_element_per_row",
+ "data": np.array([["earth"], ["wind"], ["fire"], ["wind"], ["and"]]),
+ "vocab_size": 3,
+ "expected_accumulator_output": {
+ "vocab": np.array(["wind", "and", "earth", "fire"]),
+ "counts": np.array([2, 1, 1, 1]),
+ },
+ "expected_extract_output": {
+ "vocab": np.array(["wind", "fire", "earth"]),
+ },
+ },
+ # Which tokens are retained is based on global frequency, and thus is
+ # sensitive to frequency within a document. In contrast, because idf only
+ # considers the presence of a token in a document, it is insensitive
+ # to the frequency of the token within the document.
+ {
+ "testcase_name":
+ "retained_tokens_sensitive_to_within_document_frequency",
+ "data":
+ np.array([["earth", "earth"], ["wind", "wind"], ["fire", "fire"],
+ ["wind", "wind"], ["and", "michigan"]]),
+ "vocab_size":
+ 3,
+ "expected_accumulator_output": {
+ "vocab": np.array(["wind", "earth", "fire", "and", "michigan"]),
+ "counts": np.array([4, 2, 2, 1, 1]),
+ },
+ "expected_extract_output": {
+ "vocab": np.array(["wind", "fire", "earth"]),
+ },
+ })
+ def test_combiner_computation(self, data, vocab_size,
+ expected_accumulator_output,
+ expected_extract_output):
+ combiner = index_lookup._IndexLookupCombiner(vocab_size=vocab_size)
+ expected_accumulator = combiner._create_accumulator()
+ expected_accumulator = self.update_accumulator(expected_accumulator,
+ expected_accumulator_output)
+ self.validate_accumulator_computation(combiner, data, expected_accumulator)
+ self.validate_accumulator_extract(combiner, data, expected_extract_output)
+
+
+if __name__ == "__main__":
+ test.main()
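A sketch of the combiner contract these tests exercise through the shared harness: `compute` builds an accumulator of token counts, `merge` sums accumulators, and `extract` keeps the top `vocab_size` tokens ordered by count (ties broken by reverse token order):

```python
from tensorflow.python.keras.layers.preprocessing import index_lookup

combiner = index_lookup._IndexLookupCombiner(vocab_size=2)

# Each input is a list of documents; each document is a list of tokens.
acc_a = combiner.compute([["earth", "wind"], ["wind"]])
acc_b = combiner.compute([["fire", "wind"]])
merged = combiner.merge([acc_a, acc_b])

# "wind" has count 3; "earth" and "fire" tie at 1, and "fire" sorts first.
print(combiner.extract(merged))  # {'vocab': ['wind', 'fire']}
```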
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py
new file mode 100644
index 0000000..cb5691a
--- /dev/null
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_v1.py
@@ -0,0 +1,86 @@
+# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tensorflow V1 version of the text vectorization preprocessing layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.engine import base_preprocessing_layer_v1
+from tensorflow.python.keras.layers.preprocessing import index_lookup
+from tensorflow.python.ops.ragged import ragged_tensor_value
+
+
+class IndexLookup(index_lookup.IndexLookup,
+ base_preprocessing_layer_v1.CombinerPreprocessingLayer):
+ """IndexLookup layer.
+
+ This layer translates a set of arbitrary strings or integers into an integer
+ output via a table-based lookup, with optional out-of-vocabulary handling.
+
+ If desired, the user can call this layer's adapt() method on a data set.
+ When this layer is adapted, it will analyze the dataset, determine the
+ frequency of individual string or integer values, and create a vocabulary
+ from them. This vocabulary can have unlimited size or be capped, depending on
+ the configuration options for this layer; if there are more unique values in
+ the input than the maximum vocabulary size, the most frequent terms will be
+ used to create the vocabulary.
+
+ Attributes:
+ max_tokens: The maximum size of the vocabulary for this layer. If None,
+ there is no cap on the size of the vocabulary. Note that the vocabulary
+ does include OOV buckets, so the effective number of unique values in the
+ vocabulary is (max_tokens - num_oov_tokens) when this value is set.
+ num_oov_tokens: The number of out-of-vocabulary tokens to use; defaults to
+ 1. If this value is more than 1, OOV inputs are hashed to determine their
+ OOV value; if this value is 0, passing an OOV input will result in a
+ runtime error.
+ reserve_zero: Whether to reserve the index '0', which has a special meaning
+ in the Keras masking system. If True, the output of this layer will be in
+ the range [1...max_tokens+1); if False, the output will be in the
+ range [0...max_tokens). Defaults to True.
+ mask_zero: If True, input values of 0 (for integers) and "" (for strings)
+ will be treated as masked values and assigned an output value of 0. If
+ this option is set, reserve_zero must also be set. Defaults to False.
+ """
+
+ def _get_table_data(self):
+ keys, values = self._table.export()
+ np_keys = K.get_session().run(keys)
+ np_values = K.get_session().run(values)
+ return (np_keys, np_values)
+
+ def vocab_size(self):
+ return K.get_session().run(self._table.size())
+
+ def _clear_table(self):
+ keys, _ = self._table.export()
+ K.get_session().run(self._table.remove(keys))
+
+ def _insert_table_data(self, keys, values):
+ K.get_session().run(self._table.insert(keys, values))
+
+ def _to_numpy(self, data):
+ """Converts preprocessed inputs into numpy arrays."""
+ if isinstance(data, np.ndarray):
+ return data
+ session = K.get_session()
+ data = session.run(data)
+ if isinstance(data, ragged_tensor_value.RaggedTensorValue):
+ data = np.array(data.to_list())
+ return data
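The V1 subclass exists because graph-mode table ops only produce values once run in a session. A sketch of the behavior the shims preserve, assuming a V1 graph-mode context:

```python
from tensorflow.python.keras.layers.preprocessing import index_lookup_v1

layer = index_lookup_v1.IndexLookup(max_tokens=None)
layer.set_vocabulary(["earth", "wind"])

# These run the export/size ops in K.get_session() rather than calling
# .numpy() as the V2 class does.
print(layer.vocab_size())      # 2
print(layer.get_vocabulary())  # ['earth', 'wind'] (tokens may be bytes)
```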
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
index a315df0..64fa210 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
@@ -32,13 +32,12 @@
from tensorflow.python.keras.engine.base_preprocessing_layer import Combiner
from tensorflow.python.keras.engine.base_preprocessing_layer import CombinerPreprocessingLayer
from tensorflow.python.keras.layers.preprocessing import categorical_encoding
+from tensorflow.python.keras.layers.preprocessing import index_lookup
from tensorflow.python.keras.utils import layer_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_string_ops
-from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import string_ops
-from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_string_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import compat
@@ -219,7 +218,7 @@
# 'standardize' must be one of (None, LOWER_AND_STRIP_PUNCTUATION, callable)
layer_utils.validate_string_arg(
standardize,
- allowable_strings=[LOWER_AND_STRIP_PUNCTUATION],
+ allowable_strings=(LOWER_AND_STRIP_PUNCTUATION,),
layer_name="TextVectorization",
arg_name="standardize",
allow_none=True,
@@ -228,7 +227,7 @@
# 'split' must be one of (None, SPLIT_ON_WHITESPACE, callable)
layer_utils.validate_string_arg(
split,
- allowable_strings=[SPLIT_ON_WHITESPACE],
+ allowable_strings=(SPLIT_ON_WHITESPACE,),
layer_name="TextVectorization",
arg_name="split",
allow_none=True,
@@ -237,7 +236,7 @@
# 'output_mode' must be one of (None, INT, COUNT, BINARY, TFIDF)
layer_utils.validate_string_arg(
output_mode,
- allowable_strings=[INT, COUNT, BINARY, TFIDF],
+ allowable_strings=(INT, COUNT, BINARY, TFIDF),
layer_name="TextVectorization",
arg_name="output_mode",
allow_none=True)
@@ -303,24 +302,9 @@
self._max_vocab_size, compute_idf=output_mode == TFIDF),
**kwargs)
- self._table = lookup_ops.MutableHashTable(
- key_dtype=dtypes.string,
- value_dtype=dtypes.int64,
- default_value=self._oov_value,
- name=(self._name + "_index_table"))
-
- def fail(_):
- raise NotImplementedError(
- "Saving is not yet supported for TextVectorization layers.")
- self._table._list_extra_dependencies_for_serialization = fail # pylint: disable=protected-access
-
- tracked_table = self._add_trackable(self._table, trainable=False)
-
- # This is a workaround for summary() on this layer. Because the table is
- # not mutable during training, the effective number of parameters (and so
- # the weight shape) is 0; we add this as an attr so that the parameter
- # counting code in the Model object doesn't throw an attribute error.
- tracked_table.shape = tensor_shape.TensorShape((0,))
+ reserve_zero = output_mode in [None, INT]
+ self._index_lookup_layer = self._get_index_lookup_class()(
+ max_tokens=max_tokens, reserve_zero=reserve_zero, dtype=dtypes.string)
# If this layer is configured for string or integer output, we do not
# create a vectorization layer (as the output is not vectorized).
@@ -328,11 +312,11 @@
return
if max_tokens is not None and self._pad_to_max:
- vectorize_max_tokens = max_tokens
+ max_elements = max_tokens
else:
- vectorize_max_tokens = None
+ max_elements = None
self._vectorize_layer = self._get_vectorization_class()(
- max_tokens=vectorize_max_tokens, output_mode=self._output_mode)
+ max_tokens=max_elements, output_mode=self._output_mode)
# These are V1/V2 shim points. There are V1 implementations in the V1 class.
def _get_vectorization_class(self):
@@ -342,31 +326,8 @@
keys, values = self._table.export()
return (keys.numpy(), values.numpy())
- def _get_table_size(self):
- return self._table.size().numpy()
-
- def _clear_table(self):
- if (self._output_mode in [BINARY, COUNT, TFIDF] and self._called and
- not self._pad_to_max):
- raise RuntimeError(("When using TextVectorization in {mode} mode, the "
- "vocabulary cannot be changed after the layer is "
- "called.").format(mode=self._output_mode))
- keys, _ = self._table.export()
- self._table.remove(keys)
- self._vocab_size = 0
-
- def _insert_table_data(self, keys, values):
- if (self._output_mode in [BINARY, COUNT, TFIDF] and self._called and
- not self._pad_to_max):
- raise RuntimeError(("When using TextVectorization in {mode} mode, the "
- "vocabulary cannot be changed after the layer is "
- "called.").format(mode=self._output_mode))
- if len(values) != len(keys):
- raise RuntimeError("Size mismatch between values and key arrays. "
- "Keys had size %s, values had size %s." %
- (len(keys), len(values)))
- self._table.insert(keys, values)
- self._vocab_size += len(keys)
+ def _get_index_lookup_class(self):
+ return index_lookup.IndexLookup
def _to_numpy(self, preprocessed_data):
"""Converts preprocessed inputs into numpy arrays."""
@@ -441,13 +402,7 @@
super(TextVectorization, self).adapt(preprocessed_inputs, reset_state)
def get_vocabulary(self):
- if self._vocab_size == 0:
- return []
-
- keys, values = self._get_table_data()
- # This is required because the MutableHashTable doesn't preserve insertion
- # order, but we rely on the order of the array to assign indices.
- return [x for _, x in sorted(zip(values, keys))]
+ return self._index_lookup_layer.get_vocabulary()
def get_config(self):
config = {
@@ -496,15 +451,33 @@
Raises:
ValueError: If there are too many inputs, the inputs do not match, or
input data is missing.
+ RuntimeError: If the vocabulary cannot be set when this function is
+ called. This happens in "binary", "count", and "tfidf" modes if
+ "pad_to_max_tokens" is False and the layer has already been called.
"""
- current_table_size = self._get_table_size()
- total_vocab_size = len(vocab) + (current_table_size if append else 0)
- if self._max_tokens is not None and total_vocab_size > self._max_vocab_size:
- raise ValueError(
- "Attempted to set a vocabulary larger than the maximum vocab size. "
- "Passed vocab size is %s, max vocab size is %s. Note that the OOV "
- "token is automatically added to the number of tokens." %
- (total_vocab_size, self._max_vocab_size))
+ if self._output_mode != TFIDF and df_data is not None:
+ raise ValueError("df_data should only be set if output_mode is TFIDF. "
+ "output_mode is %s." % self._output_mode)
+
+ if (self._output_mode in [BINARY, COUNT, TFIDF] and self._called and
+ not self._pad_to_max):
+ raise RuntimeError(("When using TextVectorization in {mode} mode and "
+ "pad_to_max_tokens is False, the vocabulary cannot "
+ "be changed after the layer is "
+ "called.").format(mode=self._output_mode))
+
+ current_table_size = self._index_lookup_layer.vocab_size()
+ self._index_lookup_layer.set_vocabulary(vocab, append)
+
+ # When doing raw or integer output, we don't have a Vectorize layer to
+ # manage. In this case, we can return directly.
+ if self._output_mode in [None, INT]:
+ return
+
+ if not self._pad_to_max or self._max_tokens is None:
+ num_tokens = self._index_lookup_layer.vocab_size() + self._reserved_values
+ self._vectorize_layer.set_num_elements(num_tokens)
# We're only _really_ appending if the table_size is nonzero. This is
# important for some sanity checks in tfidf mode (specifically, checking if
@@ -522,35 +495,7 @@
raise ValueError("You must pass an oov_df_value the first time "
"'set_vocabulary' is called when output_mode is "
"TFIDF.")
- else:
- if df_data is not None:
- raise ValueError("df_data should only be set if output_mode is TFIDF. "
- "output_mode is %s." % self._output_mode)
- start_index = self._reserved_values + (
- self._get_table_size() if append else 0)
- values = np.arange(start_index, len(vocab) + start_index, dtype=np.int64)
-
- vocab = self._convert_to_ndarray(vocab)
- self._assert_same_type(dtypes.string, vocab, "vocab")
-
- values = self._convert_to_ndarray(values)
- self._assert_same_type(dtypes.int64, values, "values")
-
- if not append and self._vocab_size > 0:
- self._clear_table()
- self._insert_table_data(vocab, values)
-
- # When doing raw or integer output, we don't have a Vectorize layer to
- # manage. In this case, we can return directly.
- if self._output_mode in [None, INT]:
- return
-
- if not self._pad_to_max or self._max_tokens is None:
- num_tokens = total_vocab_size + self._reserved_values
- self._vectorize_layer.set_num_elements(num_tokens)
-
- if self._output_mode == TFIDF:
df_data = self._convert_to_ndarray(df_data)
if append:
# The existing IDF data is stored in a Keras weight, so we can get it
@@ -584,9 +529,6 @@
"dimension of the input array must be 1, got shape "
"{}".format(input_shape))
- # This handles a corner case where, if restored from weights or SavedModel,
- # the layer might not have accurate vocab size information.
- self._vocab_size = self._get_table_size()
super(TextVectorization, self).build(input_shape)
def _set_state_variables(self, updates):
@@ -646,13 +588,7 @@
if self._output_mode is None:
return inputs
- # The table lookup ops don't natively support ragged tensors, so if we have
- # a RT we need to use map_flat_values to look up every element.
- if ragged_tensor.is_ragged(inputs):
- indexed_data = ragged_functional_ops.map_flat_values(
- self._table.lookup, inputs)
- else:
- indexed_data = self._table.lookup(inputs)
+ indexed_data = self._index_lookup_layer(inputs)
if self._output_mode == INT:
# Once we have the dense tensor, we can return it if we weren't given a
@@ -687,7 +623,7 @@
# A note on this combiner: This contains functionality that will be extracted
-# into the Vectorization and Lookup combiner objects. At that point,
+# into the Vectorization and IndexLookup combiner objects. At that point,
# TextVectorization can become a PreprocessingStage instead of a Layer and
# this combiner can be retired. Until then, we leave this as is instead of
# attempting a refactor of what will soon be deleted.
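An illustrative outline (not the real class) of the refactor in this file: TextVectorization no longer owns a MutableHashTable; lookup state lives in an IndexLookup sublayer, and the V1/V2 variants select the matching sublayer class through `_get_index_lookup_class()`:

```python
from tensorflow.python.keras.layers.preprocessing import index_lookup


class TextVectorizationOutline(object):
  """Simplified stand-in showing the delegation pattern only."""

  def __init__(self, max_tokens, output_mode):
    # Index 0 is only reserved when raw integer indices are the output.
    reserve_zero = output_mode in (None, "int")
    self._index_lookup_layer = self._get_index_lookup_class()(
        max_tokens=max_tokens, reserve_zero=reserve_zero)

  def _get_index_lookup_class(self):
    return index_lookup.IndexLookup  # index_lookup_v1.IndexLookup in V1

  def get_vocabulary(self):
    return self._index_lookup_layer.get_vocabulary()

  def call(self, inputs):
    # The sublayer handles ragged inputs via map_flat_values internally.
    return self._index_lookup_layer(inputs)
```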
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
index 8c5b7f1..b869bee 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_v1.py
@@ -23,6 +23,7 @@
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import base_preprocessing_layer_v1
from tensorflow.python.keras.layers.preprocessing import categorical_encoding_v1
+from tensorflow.python.keras.layers.preprocessing import index_lookup_v1
from tensorflow.python.keras.layers.preprocessing import text_vectorization
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.util.tf_export import keras_export
@@ -82,37 +83,8 @@
def _get_vectorization_class(self):
return categorical_encoding_v1.CategoricalEncoding
- def _get_table_data(self):
- keys, values = self._table.export()
- np_keys = K.get_session().run(keys)
- np_values = K.get_session().run(values)
- return (np_keys, np_values)
-
- def _get_table_size(self):
- return K.get_session().run(self._table.size())
-
- def _clear_table(self):
- if (self._output_mode in [
- text_vectorization.BINARY, text_vectorization.COUNT,
- text_vectorization.TFIDF
- ] and self._called and not self._pad_to_max):
- raise RuntimeError(("When using TextVectorization in {mode} mode, the "
- "vocabulary cannot be changed after the layer is "
- "called.").format(mode=self._output_mode))
- keys, _ = self._table.export()
- K.get_session().run(self._table.remove(keys))
- self._vocab_size = 0
-
- def _insert_table_data(self, keys, values):
- if (self._output_mode in [
- text_vectorization.BINARY, text_vectorization.COUNT,
- text_vectorization.TFIDF
- ] and self._called and not self._pad_to_max):
- raise RuntimeError(("When using TextVectorization in {mode} mode, the "
- "vocabulary cannot be changed after the layer is "
- "called.").format(mode=self._output_mode))
- K.get_session().run(self._table.insert(keys, values))
- self._vocab_size += len(keys)
+ def _get_index_lookup_class(self):
+ return index_lookup_v1.IndexLookup
def _to_numpy(self, data):
"""Converts preprocessed inputs into numpy arrays."""
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 41e473f..a386792 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -228,6 +228,7 @@
layer.adapt(adapt_data)
model = keras.models.Sequential()
+ model.add(keras.layers.Input(shape=input_shape[1:], dtype=input_dtype))
model.add(layer)
actual_output = model.predict(input_data)
actual_output_shape = actual_output.shape
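The added line pins the model's input shape and dtype before the layer under test is attached; a sketch of why, with `layer` standing in for a string-dtype preprocessing layer under test:

```python
from tensorflow.python import keras
from tensorflow.python.framework import dtypes

model = keras.models.Sequential()
# Without an explicit Input, Sequential builds the first layer against a
# default float32 placeholder, which fails for string-dtype layers.
model.add(keras.layers.Input(shape=(None,), dtype=dtypes.string))
model.add(layer)
```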
diff --git a/tensorflow/python/keras/utils/layer_utils.py b/tensorflow/python/keras/utils/layer_utils.py
index f2689d0..dcb42ab 100644
--- a/tensorflow/python/keras/utils/layer_utils.py
+++ b/tensorflow/python/keras/utils/layer_utils.py
@@ -86,7 +86,7 @@
else:
allowed_args = '`None`, ' if allow_none else ''
allowed_args += 'a `Callable`, ' if allow_callables else ''
- allowed_args += 'or one of the following values: %s' % allowable_strings
+ allowed_args += 'or one of the following values: %s' % (allowable_strings,)
raise ValueError(("%s's %s arg received an invalid value %s. " +
'Allowed values are %s.') %
(layer_name, arg_name, input_data, allowed_args))
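The one-tuple in the `%`-format fix above matters because a tuple on the right-hand side of `%` is unpacked into separate format slots; wrapping it keeps the whole tuple as a single `%s` argument:

```python
allowable = ("int", "count", "binary", "tf-idf")

# 'values: %s' % allowable  ->  TypeError: not all arguments converted
print('values: %s' % (allowable,))
# values: ('int', 'count', 'binary', 'tf-idf')
```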