diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py index 3f9c36eb4ce..6080d0b7ede 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization.py @@ -568,7 +568,7 @@ class TextVectorization(CombinerPreprocessingLayer): self.set_vocabulary(updates[_VOCAB_NAME]) def _preprocess(self, inputs): - if self._standardize is LOWER_AND_STRIP_PUNCTUATION: + if self._standardize == LOWER_AND_STRIP_PUNCTUATION: lowercase_inputs = gen_string_ops.string_lower(inputs) inputs = string_ops.regex_replace(lowercase_inputs, DEFAULT_STRIP_REGEX, "") @@ -586,7 +586,7 @@ class TextVectorization(CombinerPreprocessingLayer): # so can be squeezed out. We do this here instead of after splitting for # performance reasons - it's more expensive to squeeze a ragged tensor. inputs = array_ops.squeeze(inputs, axis=1) - if self._split is SPLIT_ON_WHITESPACE: + if self._split == SPLIT_ON_WHITESPACE: # This treats multiple whitespaces as one whitespace, and strips leading # and trailing whitespace. inputs = ragged_string_ops.string_split_v2(inputs) diff --git a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py index 8f65c481400..b20b0164247 100644 --- a/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py +++ b/tensorflow/python/keras/layers/preprocessing/text_vectorization_test.py @@ -470,6 +470,34 @@ class TextVectorizationPreprocessingTest( with self.assertRaisesRegex(ValueError, ".*is not a supported splitting.*"): _ = layer(input_data) + def test_standardize_with_no_identical_argument(self): + input_array = np.array([["hello world"]]) + expected_output = np.array([[1, 1]]) + + standardize = "".join(["lower", "_and_strip_punctuation"]) + layer = get_layer_class()(standardize=standardize) + + input_data = keras.Input(shape=(1,), dtype=dtypes.string) + output_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=output_data) + output = model.predict(input_array) + + self.assertAllEqual(expected_output, output) + + def test_splitting_with_no_identical_argument(self): + input_array = np.array([["hello world"]]) + expected_output = np.array([[1, 1]]) + + split = "".join(["white", "space"]) + layer = get_layer_class()(split=split) + + input_data = keras.Input(shape=(1,), dtype=dtypes.string) + output_data = layer(input_data) + model = keras.Model(inputs=input_data, outputs=output_data) + output = model.predict(input_array) + + self.assertAllEqual(expected_output, output) + @keras_parameterized.run_all_keras_modes class TextVectorizationOutputTest(