From a92ff929b818c7dbca2d0c2648ae17e8d6ae3a40 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 26 May 2020 09:31:32 -0700
Subject: [PATCH] Fix dimensionality handling issues in TextVectorization.

PiperOrigin-RevId: 313206360
Change-Id: I5929e83b26011c975561e525b90aef6949a185b2
---
 .../keras/layers/preprocessing/table_utils.py |  2 +
 .../preprocessing/text_vectorization.py       | 24 ++++----
 .../preprocessing/text_vectorization_test.py  | 57 ++++++++++++++++++-
 3 files changed, 68 insertions(+), 15 deletions(-)

diff --git a/tensorflow/python/keras/layers/preprocessing/table_utils.py b/tensorflow/python/keras/layers/preprocessing/table_utils.py
index 05447f6e9ff..16ac633f8dd 100644
--- a/tensorflow/python/keras/layers/preprocessing/table_utils.py
+++ b/tensorflow/python/keras/layers/preprocessing/table_utils.py
@@ -87,6 +87,8 @@ class TableHandler(object):
         self.table.lookup, inputs)
     indexed_data = ragged_functional_ops.map_flat_values(
         self._replace_oov_buckets, inputs, indexed_data)
+    # table.lookup is not shape-preserving, so we need to set the shape here.
+    indexed_data._set_shape(inputs.shape)  # pylint: disable=protected-access
     # Composite tensors can pass tensor values through, which will cause
     # errors if all operations in the TF graph do so. We can break this chain
     # with an identity here.
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
index 057575d4ecc..28d339ea5b1 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py
@@ -490,11 +490,12 @@ class TextVectorization(CombinerPreprocessingLayer):
     # in None for undefined shape axes. If using 'and !=', this causes the
     # expression to evaluate to False instead of True if the shape is undefined;
     # the expression needs to evaluate to True in that case.
-    if self._split is not None and not input_shape[1] == 1:  # pylint: disable=g-comparison-negation
-      raise RuntimeError(
-          "When using TextVectorization to tokenize strings, the first "
-          "dimension of the input array must be 1, got shape "
-          "{}".format(input_shape))
+    if self._split is not None:
+      if input_shape.ndims > 1 and not input_shape[-1] == 1:  # pylint: disable=g-comparison-negation
+        raise RuntimeError(
+            "When using TextVectorization to tokenize strings, the innermost "
+            "dimension of the input array must be 1, got shape "
+            "{}".format(input_shape))
 
     super(TextVectorization, self).build(input_shape)
 
@@ -536,7 +537,8 @@ class TextVectorization(CombinerPreprocessingLayer):
       # If we are splitting, we validate that the 1st axis is of dimension 1 and
       # so can be squeezed out. We do this here instead of after splitting for
       # performance reasons - it's more expensive to squeeze a ragged tensor.
-      inputs = array_ops.squeeze(inputs, axis=1)
+      if inputs.shape.ndims > 1:
+        inputs = array_ops.squeeze(inputs, axis=-1)
       if self._split == SPLIT_ON_WHITESPACE:
         # This treats multiple whitespaces as one whitespace, and strips leading
         # and trailing whitespace.
@@ -561,8 +563,6 @@ class TextVectorization(CombinerPreprocessingLayer):
   def call(self, inputs):
     if isinstance(inputs, (list, tuple, np.ndarray)):
       inputs = ops.convert_to_tensor(inputs)
-    if inputs.shape.rank == 1:
-      inputs = array_ops.expand_dims(inputs, axis=-1)
 
     self._called = True
     inputs = self._preprocess(inputs)
@@ -570,9 +570,7 @@
     # If we're not doing any output processing, return right away.
     if self._output_mode is None:
       return inputs
-
     indexed_data = self._index_lookup_layer(inputs)
-
     if self._output_mode == INT:
       # Once we have the dense tensor, we can return it if we weren't given a
       # fixed output sequence length. If we were, though, we have to dynamically
@@ -585,7 +583,6 @@
         dense_data = indexed_data
 
       if self._output_sequence_length is None:
-        dense_data.set_shape(tensor_shape.TensorShape((None, None)))
         return dense_data
       else:
         sequence_len = K.shape(dense_data)[1]
@@ -596,8 +593,9 @@
             sequence_len < self._output_sequence_length,
             true_fn=pad_fn,
             false_fn=slice_fn)
-        output_tensor.set_shape(
-            tensor_shape.TensorShape((None, self._output_sequence_length)))
+        output_shape = output_tensor.shape.as_list()
+        output_shape[-1] = self._output_sequence_length
+        output_tensor.set_shape(tensor_shape.TensorShape(output_shape))
         return output_tensor
 
     # If we're not returning integers here, we rely on the vectorization layer
diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py
index 5d909498d8a..508f222eac7 100644
--- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py
+++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py
@@ -355,6 +355,59 @@ class TextVectorizationLayerTest(keras_parameterized.TestCase,
     if context.executing_eagerly():
       self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
 
+  @parameterized.named_parameters(
+      {
+          "testcase_name": "1d",
+          "data": ["0", "a", "b", "c", "d", "e", "a", "b", "c", "d", "f"],
+          "expected": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1]
+      },
+      {
+          "testcase_name": "2d",
+          "data": [["0", "a", "b", "c", "d"], ["e", "a", "b", "c", "d"], ["f"]],
+          "expected": [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 0, 0, 0, 0]]
+      },
+      {
+          "testcase_name":
+              "3d",
+          "data": [[["0", "a", "b"], ["c", "d"]], [["e", "a"], ["b", "c", "d"]],
+                   [["f"]]],
+          "expected": [[[1, 2, 3], [4, 5, 0]], [[1, 2, 0], [3, 4, 5]],
+                       [[1, 0, 0], [0, 0, 0]]]
+      },
+  )
+  def test_layer_dimensionality_handling(self, data, expected):
+    vocab = ["a", "b", "c", "d"]
+    vectorization = get_layer_class()(
+        max_tokens=None, standardize=None, split=None, pad_to_max_tokens=False)
+    vectorization.set_vocabulary(vocab)
+    output = vectorization(ragged_factory_ops.constant(data))
+    self.assertAllEqual(expected, output)
+
+  @parameterized.named_parameters(
+      {
+          "testcase_name": "1d",
+          "data": ["0 a b c d e a b c d f"],
+          "expected": [[1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1]]
+      },
+      {
+          "testcase_name":
+              "3d",
+          "data": [[["0 a b"], ["c d"]], [["e a"], ["b c d"]], [["f"]]],
+          "expected": [[[1, 2, 3], [4, 5, 0]], [[1, 2, 0], [3, 4, 5]],
+                       [[1, 0, 0], [0, 0, 0]]]
+      },
+  )
+  def test_layer_dimensionality_handling_with_split(self, data, expected):
+    vocab = ["a", "b", "c", "d"]
+    vectorization = get_layer_class()(
+        max_tokens=None,
+        standardize=None,
+        split=text_vectorization.SPLIT_ON_WHITESPACE,
+        pad_to_max_tokens=False)
+    vectorization.set_vocabulary(vocab)
+    output = vectorization(ragged_factory_ops.constant(data, inner_shape=(1,)))
+    self.assertAllEqual(expected, output)
+
 
 @keras_parameterized.run_all_keras_modes
 class TextVectorizationPreprocessingTest(
@@ -580,7 +633,7 @@ class TextVectorizationPreprocessingTest(
         split=text_vectorization.SPLIT_ON_WHITESPACE,
         output_mode=None)
     with self.assertRaisesRegex(RuntimeError,
-                                ".*tokenize strings, the first dimension.*"):
+                                ".*tokenize strings, the innermost dime.*"):
       _ = layer(input_data)
 
   def test_string_splitting_with_non_1d_raggedarray_fails(self):
@@ -591,7 +644,7 @@ class TextVectorizationPreprocessingTest(
         split=text_vectorization.SPLIT_ON_WHITESPACE,
         output_mode=None)
     with self.assertRaisesRegex(RuntimeError,
-                                ".*tokenize strings, the first dimension.*"):
+                                ".*tokenize strings, the innermost dime.*"):
       _ = layer(input_data)
 
   def test_standardization_with_invalid_standardize_arg(self):
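
Illustration only, not part of the commit: a minimal sketch of the behavior
this patch enables, namely that TextVectorization accepts inputs of rank
other than 2 and validates/squeezes the innermost axis instead of
hard-coding axis 1. It assumes the TF 2.2-era public API path
(tf.keras.layers.experimental.preprocessing.TextVectorization) and the
layer's default reserved indices (0 = padding, 1 = OOV), under which the
vocabulary ["a", "b", "c", "d"] maps to indices 2-5, as in the tests above.

    import tensorflow as tf

    layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
        max_tokens=None, standardize=None, split="whitespace",
        output_mode="int")
    layer.set_vocabulary(["a", "b", "c", "d"])  # a=2, b=3, c=4, d=5

    # Rank-1 input: call() no longer force-expands it to rank 2, so each
    # string is split along a new innermost axis.
    print(layer(tf.constant(["a b c", "d a"])))
    # Expected: [[2 3 4] [5 2 0]] (ragged result padded to the longest row).

    # Rank-3 input with a size-1 innermost axis: build() now checks the
    # innermost dimension, and _preprocess() squeezes axis=-1 before
    # splitting, yielding an output of shape (1, 2, 2) here.
    print(layer(tf.constant([[["a b"], ["c d"]]])))
    # Expected: [[[2 3] [4 5]]]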