Merge pull request #34420 from dfalbel:bugfix/text-preprocessing-standardize
PiperOrigin-RevId: 285786331
Change-Id: Icd606458dce4beebfa0550277c443cd20d39f2c1
Commit: dcbe8eaaac
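A minimal sketch of the failure mode this change fixes (not part of the commit; the constant name is taken from the diff, its value is illustrative): `is` tests object identity while `==` tests value equality, so a standardize/split argument that equals the module constant but was built at runtime can fail the old identity check.

# Minimal sketch, assuming the option constants are plain Python strings.
LOWER_AND_STRIP_PUNCTUATION = "lower_and_strip_punctuation"

# Equal in value, but constructed at runtime, so typically a distinct object.
standardize = "".join(["lower", "_and_strip_punctuation"])

print(standardize == LOWER_AND_STRIP_PUNCTUATION)  # True: value equality
print(standardize is LOWER_AND_STRIP_PUNCTUATION)  # usually False: different objects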
@@ -568,7 +568,7 @@ class TextVectorization(CombinerPreprocessingLayer):
       self.set_vocabulary(updates[_VOCAB_NAME])

   def _preprocess(self, inputs):
-    if self._standardize is LOWER_AND_STRIP_PUNCTUATION:
+    if self._standardize == LOWER_AND_STRIP_PUNCTUATION:
       lowercase_inputs = gen_string_ops.string_lower(inputs)
       inputs = string_ops.regex_replace(lowercase_inputs, DEFAULT_STRIP_REGEX,
                                         "")
@@ -586,7 +586,7 @@ class TextVectorization(CombinerPreprocessingLayer):
       # so can be squeezed out. We do this here instead of after splitting for
       # performance reasons - it's more expensive to squeeze a ragged tensor.
       inputs = array_ops.squeeze(inputs, axis=1)
-    if self._split is SPLIT_ON_WHITESPACE:
+    if self._split == SPLIT_ON_WHITESPACE:
       # This treats multiple whitespaces as one whitespace, and strips leading
       # and trailing whitespace.
       inputs = ragged_string_ops.string_split_v2(inputs)
@@ -470,6 +470,34 @@ class TextVectorizationPreprocessingTest(
     with self.assertRaisesRegex(ValueError, ".*is not a supported splitting.*"):
       _ = layer(input_data)

+  def test_standardize_with_no_identical_argument(self):
+    input_array = np.array([["hello world"]])
+    expected_output = np.array([[1, 1]])
+
+    standardize = "".join(["lower", "_and_strip_punctuation"])
+    layer = get_layer_class()(standardize=standardize)
+
+    input_data = keras.Input(shape=(1,), dtype=dtypes.string)
+    output_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=output_data)
+    output = model.predict(input_array)
+
+    self.assertAllEqual(expected_output, output)
+
+  def test_splitting_with_no_identical_argument(self):
+    input_array = np.array([["hello world"]])
+    expected_output = np.array([[1, 1]])
+
+    split = "".join(["white", "space"])
+    layer = get_layer_class()(split=split)
+
+    input_data = keras.Input(shape=(1,), dtype=dtypes.string)
+    output_data = layer(input_data)
+    model = keras.Model(inputs=input_data, outputs=output_data)
+    output = model.predict(input_array)
+
+    self.assertAllEqual(expected_output, output)
+

 @keras_parameterized.run_all_keras_modes
 class TextVectorizationOutputTest(
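Design note on the added tests: each builds its argument with "".join([...]) so that the string is equal in value to the module constant but is not the same object. That is exactly the case the old `is` comparisons missed, so these tests fail before the fix and pass with `==`.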