From 19ed4a9ccfca2565f130df523e630fedec68728d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Thu, 21 May 2020 16:04:08 -0700
Subject: [PATCH] Fix issues where index_lookup was improperly handling hard
 vocab caps. Add tests.

PiperOrigin-RevId: 312759072
Change-Id: Id24687eee01a6898473e128b8c2cfeb13be89547
---
 .../layers/preprocessing/index_lookup.py      |  9 ++-
 .../layers/preprocessing/index_lookup_test.py | 75 ++++++++++++++++++-
 .../preprocessing/text_vectorization_test.py  | 34 +++++++++
 3 files changed, 115 insertions(+), 3 deletions(-)

diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup.py b/tensorflow/python/keras/layers/preprocessing/index_lookup.py
index c0d0d266ad3..7d11feae341 100644
--- a/tensorflow/python/keras/layers/preprocessing/index_lookup.py
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup.py
@@ -118,9 +118,14 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
     else:
       self._oov_value = -1
 
+    if max_tokens is not None:
+      num_mask_tokens = (0 if mask_token is None else 1)
+      vocab_size = max_tokens - (num_oov_indices + num_mask_tokens)
+    else:
+      vocab_size = None
+
     super(IndexLookup, self).__init__(
-        combiner=_IndexLookupCombiner(self.max_tokens, self.mask_token),
-        **kwargs)
+        combiner=_IndexLookupCombiner(vocab_size, self.mask_token), **kwargs)
 
     self._output_dtype = dtypes.int64
 
diff --git a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py
index 73189d9b9f1..a61cef6121f 100644
--- a/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/index_lookup_test.py
@@ -86,7 +86,8 @@ def _get_end_to_end_test_cases():
           "vocab_data":
               np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
                         ["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
-          "input_data": np.array([[1], [2], [3], [4], [4], [3], [1], [5]]),
+          "input_data":
+              np.array([[1], [2], [3], [4], [4], [3], [1], [5]]),
           "kwargs": {
               "max_tokens": None,
               "num_oov_indices": 1,
@@ -125,6 +126,78 @@ def _get_end_to_end_test_cases():
           "input_dtype":
               dtypes.int64
       },
+      {
+          "testcase_name":
+              "test_strings_hard_vocab_cap",
+          # Create an array where 'earth' is the most frequent term, followed by
+          # 'wind', then 'and', then 'fire'. This ensures that the vocab
+          # accumulator is sorting by frequency.
+          "vocab_data":
+              np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"],
+                        ["wind"], ["wind"], ["wind"], ["and"], ["and"]]),
+          "input_data":
+              np.array([["earth"], ["wind"], ["and"], ["fire"], ["fire"],
+                        ["and"], ["earth"], ["michigan"]]),
+          "kwargs": {
+              "max_tokens": 5,
+              "num_oov_indices": 1,
+              "mask_token": "",
+              "oov_token": "[OOV]",
+              "dtype": dtypes.string,
+          },
+          "expected_output": [[2], [3], [4], [1], [1], [4], [2], [1]],
+          "input_dtype":
+              dtypes.string
+      },
+      {
+          "testcase_name":
+              "test_inverse_strings_hard_vocab_cap",
+          # Create an array where 'earth' is the most frequent term, followed by
+          # 'wind', then 'and', then 'fire'. This ensures that the vocab
+          # accumulator is sorting by frequency.
+ "vocab_data": + np.array([["fire"], ["earth"], ["earth"], ["earth"], ["earth"], + ["wind"], ["wind"], ["wind"], ["and"], ["and"]]), + "input_data": + np.array([[1], [2], [3], [4], [4], [3], [1], [5]]), + "kwargs": { + "max_tokens": 5, + "num_oov_indices": 1, + "mask_token": "", + "oov_token": "[OOV]", + "dtype": dtypes.string, + "invert": True + }, + "expected_output": + np.array([[b"earth"], [b"wind"], [b"and"], [b"[OOV]"], [b"[OOV]"], + [b"and"], [b"earth"], [b"[OOV]"]]), + "input_dtype": + dtypes.int64 + }, + { + "testcase_name": + "test_ints_hard_vocab_cap", + # Create an array where 1138 is the most frequent term, followed by + # 1729, then 725, then 42. This ensures that the vocab accumulator + # is sorting by frequency. + "vocab_data": + np.array([[42], [1138], [1138], [1138], [1138], [1729], [1729], + [1729], [725], [725]], + dtype=np.int64), + "input_data": + np.array([[1138], [1729], [725], [42], [42], [725], [1138], [4]], + dtype=np.int64), + "kwargs": { + "max_tokens": 5, + "num_oov_indices": 1, + "mask_token": 0, + "oov_token": -1, + "dtype": dtypes.int64, + }, + "expected_output": [[2], [3], [4], [1], [1], [4], [2], [1]], + "input_dtype": + dtypes.int64 + }, ) crossed_test_cases = [] diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py index affa392e42b..5d909498d8a 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py @@ -1510,5 +1510,39 @@ class TextVectorizationSavingTest( self.assertAllEqual(expected_output, new_output_dataset) +@keras_parameterized.run_all_keras_modes +class TextVectorizationE2ETest(keras_parameterized.TestCase, + preprocessing_test_utils.PreprocessingLayerTest): + + def test_keras_vocab_trimming_example(self): + vocab_data = np.array([ + "earth", "earth", "earth", "earth", "wind", "wind", "wind", "and", + "and", "fire" + ]) + input_array = np.array([["earth", "wind", "and", "earth"], + ["ohio", "and", "earth", "michigan"]]) + + # pyformat: disable + expected_output = [[1, 2, 1], + [3, 1, 0]] + # pyformat: enable + max_tokens = 3 + expected_output_shape = [None, max_tokens] + + input_data = keras.Input(shape=(None,), dtype=dtypes.string) + layer = get_layer_class()( + max_tokens=max_tokens, + standardize=None, + split=None, + output_mode=text_vectorization.COUNT, + pad_to_max_tokens=True) + int_data = layer(input_data) + layer.adapt(vocab_data) + self.assertAllEqual(expected_output_shape, int_data.shape.as_list()) + model = keras.Model(input_data, int_data) + output = model.predict(input_array) + self.assertAllEqual(expected_output, output) + + if __name__ == "__main__": test.main()