Merge pull request #34420 from dfalbel:bugfix/text-preprocessing-standardize
PiperOrigin-RevId: 285786331
Change-Id: Icd606458dce4beebfa0550277c443cd20d39f2c1
This commit is contained in:
commit
dcbe8eaaac
@ -568,7 +568,7 @@ class TextVectorization(CombinerPreprocessingLayer):
|
||||
self.set_vocabulary(updates[_VOCAB_NAME])
|
||||
|
||||
def _preprocess(self, inputs):
|
||||
if self._standardize is LOWER_AND_STRIP_PUNCTUATION:
|
||||
if self._standardize == LOWER_AND_STRIP_PUNCTUATION:
|
||||
lowercase_inputs = gen_string_ops.string_lower(inputs)
|
||||
inputs = string_ops.regex_replace(lowercase_inputs, DEFAULT_STRIP_REGEX,
|
||||
"")
|
||||
@ -586,7 +586,7 @@ class TextVectorization(CombinerPreprocessingLayer):
|
||||
# so can be squeezed out. We do this here instead of after splitting for
|
||||
# performance reasons - it's more expensive to squeeze a ragged tensor.
|
||||
inputs = array_ops.squeeze(inputs, axis=1)
|
||||
if self._split is SPLIT_ON_WHITESPACE:
|
||||
if self._split == SPLIT_ON_WHITESPACE:
|
||||
# This treats multiple whitespaces as one whitespace, and strips leading
|
||||
# and trailing whitespace.
|
||||
inputs = ragged_string_ops.string_split_v2(inputs)
|
||||
|
@ -470,6 +470,34 @@ class TextVectorizationPreprocessingTest(
|
||||
with self.assertRaisesRegex(ValueError, ".*is not a supported splitting.*"):
|
||||
_ = layer(input_data)
|
||||
|
||||
def test_standardize_with_no_identical_argument(self):
  """Checks that a `standardize` mode string built at runtime is accepted.

  The mode name is assembled with `str.join` rather than written as a single
  literal -- presumably so that it compares equal to the module's
  standardization constant without being the same object, which is the case
  the test name refers to.
  """
  # Assemble the mode name from pieces instead of using one literal.
  mode = "".join(["lower", "_and_strip_punctuation"])
  vectorizer = get_layer_class()(standardize=mode)

  # Wrap the layer in a single-input model that takes one string per example.
  string_input = keras.Input(shape=(1,), dtype=dtypes.string)
  model = keras.Model(inputs=string_input, outputs=vectorizer(string_input))

  predictions = model.predict(np.array([["hello world"]]))

  # Both tokens of "hello world" are expected to come back as index 1.
  self.assertAllEqual(np.array([[1, 1]]), predictions)
|
||||
|
||||
def test_splitting_with_no_identical_argument(self):
  """Checks that a `split` mode string built at runtime is accepted.

  The mode name is assembled with `str.join` rather than written as a single
  literal -- presumably so that it compares equal to the module's splitting
  constant without being the same object, which is the case the test name
  refers to.
  """
  # Assemble the mode name from pieces instead of using one literal.
  mode = "".join(["white", "space"])
  vectorizer = get_layer_class()(split=mode)

  # Wrap the layer in a single-input model that takes one string per example.
  string_input = keras.Input(shape=(1,), dtype=dtypes.string)
  model = keras.Model(inputs=string_input, outputs=vectorizer(string_input))

  predictions = model.predict(np.array([["hello world"]]))

  # Both tokens of "hello world" are expected to come back as index 1.
  self.assertAllEqual(np.array([[1, 1]]), predictions)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
class TextVectorizationOutputTest(
|
||||
|
Loading…
Reference in New Issue
Block a user