Merge pull request #34420 from dfalbel:bugfix/text-preprocessing-standardize

PiperOrigin-RevId: 285786331
Change-Id: Icd606458dce4beebfa0550277c443cd20d39f2c1
This commit is contained in:
TensorFlower Gardener 2019-12-16 09:27:04 -08:00
commit dcbe8eaaac
2 changed files with 30 additions and 2 deletions

View File

@ -568,7 +568,7 @@ class TextVectorization(CombinerPreprocessingLayer):
self.set_vocabulary(updates[_VOCAB_NAME])
def _preprocess(self, inputs):
if self._standardize is LOWER_AND_STRIP_PUNCTUATION:
if self._standardize == LOWER_AND_STRIP_PUNCTUATION:
lowercase_inputs = gen_string_ops.string_lower(inputs)
inputs = string_ops.regex_replace(lowercase_inputs, DEFAULT_STRIP_REGEX,
"")
@ -586,7 +586,7 @@ class TextVectorization(CombinerPreprocessingLayer):
# so can be squeezed out. We do this here instead of after splitting for
# performance reasons - it's more expensive to squeeze a ragged tensor.
inputs = array_ops.squeeze(inputs, axis=1)
if self._split is SPLIT_ON_WHITESPACE:
if self._split == SPLIT_ON_WHITESPACE:
# This treats multiple whitespaces as one whitespace, and strips leading
# and trailing whitespace.
inputs = ragged_string_ops.string_split_v2(inputs)

View File

@ -470,6 +470,34 @@ class TextVectorizationPreprocessingTest(
with self.assertRaisesRegex(ValueError, ".*is not a supported splitting.*"):
_ = layer(input_data)
def test_standardize_with_no_identical_argument(self):
input_array = np.array([["hello world"]])
expected_output = np.array([[1, 1]])
standardize = "".join(["lower", "_and_strip_punctuation"])
layer = get_layer_class()(standardize=standardize)
input_data = keras.Input(shape=(1,), dtype=dtypes.string)
output_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=output_data)
output = model.predict(input_array)
self.assertAllEqual(expected_output, output)
def test_splitting_with_no_identical_argument(self):
input_array = np.array([["hello world"]])
expected_output = np.array([[1, 1]])
split = "".join(["white", "space"])
layer = get_layer_class()(split=split)
input_data = keras.Input(shape=(1,), dtype=dtypes.string)
output_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=output_data)
output = model.predict(input_array)
self.assertAllEqual(expected_output, output)
@keras_parameterized.run_all_keras_modes
class TextVectorizationOutputTest(