Fix dimensionality handling issues in TextVectorization.
PiperOrigin-RevId: 313206360
Change-Id: I5929e83b26011c975561e525b90aef6949a185b2
commit a92ff929b8
parent 00664cef68
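In effect, after this change TextVectorization accepts rank-1 and rank-3 (including ragged) string inputs instead of only shape (batch, 1). A minimal sketch of the new behavior, assuming a TF 2.x build with the experimental Keras preprocessing API; the printed indices are illustrative (0 = padding, 1 = OOV, vocabulary starts at 2):

import tensorflow as tf

layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
    max_tokens=None, standardize=None, split=None)
layer.set_vocabulary(["a", "b", "c", "d"])

# Rank-1 input: previously force-expanded to shape (batch, 1) in call().
print(layer(tf.constant(["a", "b", "f"])))  # e.g. [2, 3, 1]

# Rank-3 ragged input: lookup is applied along the innermost string axis.
print(layer(tf.ragged.constant([[["a", "b"], ["c"]], [["d"]]])))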
@@ -87,6 +87,8 @@ class TableHandler(object):
         self.table.lookup, inputs)
     indexed_data = ragged_functional_ops.map_flat_values(
         self._replace_oov_buckets, inputs, indexed_data)
+    # table.lookup is not shape-preserving, so we need to set the shape here.
+    indexed_data._set_shape(inputs.shape)  # pylint: disable=protected-access
     # Composite tensors can pass tensor values through, which will cause
     # errors if all operations in the TF graph do so. We can break this chain
     # with an identity here.
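For context on the added _set_shape call: as the comment in the hunk says, the lookup op does not propagate the static shape of its keys, so graph-mode shape inference can lose it. A standalone sketch with a toy StaticHashTable (names here are illustrative, not the layer's internals):

import tensorflow as tf

keys = tf.constant(["a", "b"])
vals = tf.constant([2, 3], dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, vals), default_value=1)

@tf.function(input_signature=[tf.TensorSpec([None, 1], tf.string)])
def lookup_fn(x):
  out = table.lookup(x)
  # Without this, shape inference may see <unknown> for `out`; restoring the
  # key shape keeps downstream static-shape checks working.
  out.set_shape(x.shape)
  return out

print(lookup_fn.get_concrete_function().output_shapes)  # (None, 1)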
@@ -490,9 +490,10 @@ class TextVectorization(CombinerPreprocessingLayer):
     # in None for undefined shape axes. If using 'and !=', this causes the
     # expression to evaluate to False instead of True if the shape is undefined;
     # the expression needs to evaluate to True in that case.
-    if self._split is not None and not input_shape[1] == 1:  # pylint: disable=g-comparison-negation
-      raise RuntimeError(
-          "When using TextVectorization to tokenize strings, the first "
-          "dimension of the input array must be 1, got shape "
-          "{}".format(input_shape))
+    if self._split is not None:
+      if input_shape.ndims > 1 and not input_shape[-1] == 1:  # pylint: disable=g-comparison-negation
+        raise RuntimeError(
+            "When using TextVectorization to tokenize strings, the innermost "
+            "dimension of the input array must be 1, got shape "
+            "{}".format(input_shape))
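The new rule: when splitting is enabled, rank-1 inputs pass through, and anything of higher rank must have a unit innermost axis so it can be squeezed before tokenizing. A standalone restatement with a hypothetical helper, for illustration only:

import tensorflow as tf

def check_splittable(shape):
  shape = tf.TensorShape(shape)
  if shape.ndims is not None and shape.ndims > 1 and shape[-1] != 1:
    raise RuntimeError(
        "When using TextVectorization to tokenize strings, the innermost "
        "dimension of the input array must be 1, got shape {}".format(shape))

check_splittable([None])        # ok: rank 1, nothing to squeeze
check_splittable([None, 1])     # ok: innermost dimension is 1
check_splittable([2, None, 1])  # ok: rank 3 with a unit inner axis
# check_splittable([None, 2])   # raises RuntimeError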
@@ -536,7 +537,8 @@ class TextVectorization(CombinerPreprocessingLayer):
     # If we are splitting, we validate that the 1st axis is of dimension 1 and
     # so can be squeezed out. We do this here instead of after splitting for
     # performance reasons - it's more expensive to squeeze a ragged tensor.
-    inputs = array_ops.squeeze(inputs, axis=1)
+    if inputs.shape.ndims > 1:
+      inputs = array_ops.squeeze(inputs, axis=-1)
     if self._split == SPLIT_ON_WHITESPACE:
       # This treats multiple whitespaces as one whitespace, and strips leading
       # and trailing whitespace.
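The squeeze now targets the last axis (and is skipped for rank-1 input), so higher-rank tensors with a trailing unit dimension also work. A small sketch of the squeeze-then-split flow using public ops:

import tensorflow as tf

inputs = tf.constant([["a b  c"], ["d e"]])  # shape (2, 1)
if inputs.shape.ndims > 1:
  inputs = tf.squeeze(inputs, axis=-1)       # shape (2,)
# The default tf.strings.split separator collapses runs of whitespace and
# ignores leading/trailing whitespace, matching SPLIT_ON_WHITESPACE.
tokens = tf.strings.split(inputs)
print(tokens)  # roughly: <tf.RaggedTensor [[b'a', b'b', b'c'], [b'd', b'e']]>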
@@ -561,8 +563,6 @@ class TextVectorization(CombinerPreprocessingLayer):
   def call(self, inputs):
     if isinstance(inputs, (list, tuple, np.ndarray)):
       inputs = ops.convert_to_tensor(inputs)
-    if inputs.shape.rank == 1:
-      inputs = array_ops.expand_dims(inputs, axis=-1)
 
     self._called = True
     inputs = self._preprocess(inputs)
@@ -570,9 +570,7 @@ class TextVectorization(CombinerPreprocessingLayer):
     # If we're not doing any output processing, return right away.
     if self._output_mode is None:
       return inputs
 
     indexed_data = self._index_lookup_layer(inputs)
-    # table.lookup is not shape-preserving, so we need to set the shape here.
-    indexed_data._set_shape(inputs.shape)  # pylint: disable=protected-access
 
     if self._output_mode == INT:
       # Once we have the dense tensor, we can return it if we weren't given a
       # fixed output sequence length. If we were, though, we have to dynamically
@@ -585,7 +583,6 @@ class TextVectorization(CombinerPreprocessingLayer):
         dense_data = indexed_data
 
       if self._output_sequence_length is None:
-        dense_data.set_shape(tensor_shape.TensorShape((None, None)))
         return dense_data
       else:
         sequence_len = K.shape(dense_data)[1]
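The hardcoded set_shape removed above assumed rank-2 output, which cannot hold once 3-D inputs are supported: set_shape only refines a shape and raises on a rank mismatch. A toy illustration with plain tensors and hypothetical shapes:

import tensorflow as tf

x = tf.zeros([2, 3, 4])
x.set_shape([None, None, None])  # fine: same rank, compatible dims
try:
  x.set_shape([None, None])      # rank 2 vs. rank 3
except ValueError as e:
  print(e)                       # shapes are incompatible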
@@ -596,8 +593,9 @@ class TextVectorization(CombinerPreprocessingLayer):
             sequence_len < self._output_sequence_length,
             true_fn=pad_fn,
             false_fn=slice_fn)
-        output_tensor.set_shape(
-            tensor_shape.TensorShape((None, self._output_sequence_length)))
+        output_shape = output_tensor.shape.as_list()
+        output_shape[-1] = self._output_sequence_length
+        output_tensor.set_shape(tensor_shape.TensorShape(output_shape))
         return output_tensor
 
       # If we're not returning integers here, we rely on the vectorization layer
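Instead of pinning the whole shape to (None, output_sequence_length), the new code keeps whatever leading static shape is known and fixes only the last axis. A self-contained sketch of this pad-or-slice pattern (pad_or_slice is a hypothetical stand-in for the layer's internal pad_fn/slice_fn plumbing):

import tensorflow as tf

def pad_or_slice(dense, output_sequence_length):
  """Pads or trims the last axis to output_sequence_length (sketch)."""
  seq_len = tf.shape(dense)[-1]
  rank = dense.shape.ndims
  paddings = [[0, 0]] * (rank - 1) + [
      [0, tf.maximum(output_sequence_length - seq_len, 0)]]
  out = tf.cond(
      seq_len < output_sequence_length,
      lambda: tf.pad(dense, paddings),
      lambda: dense[..., :output_sequence_length])
  # Keep the known leading static shape; pin only the last axis.
  shape = dense.shape.as_list()
  shape[-1] = output_sequence_length
  out.set_shape(shape)
  return out

print(pad_or_slice(tf.ones([2, 3], tf.int64), 5).shape)  # (2, 5)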
@@ -355,6 +355,59 @@ class TextVectorizationLayerTest(keras_parameterized.TestCase,
     if context.executing_eagerly():
       self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
 
+  @parameterized.named_parameters(
+      {
+          "testcase_name": "1d",
+          "data": ["0", "a", "b", "c", "d", "e", "a", "b", "c", "d", "f"],
+          "expected": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1]
+      },
+      {
+          "testcase_name": "2d",
+          "data": [["0", "a", "b", "c", "d"], ["e", "a", "b", "c", "d"], ["f"]],
+          "expected": [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 0, 0, 0, 0]]
+      },
+      {
+          "testcase_name":
+              "3d",
+          "data": [[["0", "a", "b"], ["c", "d"]], [["e", "a"], ["b", "c", "d"]],
+                   [["f"]]],
+          "expected": [[[1, 2, 3], [4, 5, 0]], [[1, 2, 0], [3, 4, 5]],
+                       [[1, 0, 0], [0, 0, 0]]]
+      },
+  )
+  def test_layer_dimensionality_handling(self, data, expected):
+    vocab = ["a", "b", "c", "d"]
+    vectorization = get_layer_class()(
+        max_tokens=None, standardize=None, split=None, pad_to_max_tokens=False)
+    vectorization.set_vocabulary(vocab)
+    output = vectorization(ragged_factory_ops.constant(data))
+    self.assertAllEqual(expected, output)
+
+  @parameterized.named_parameters(
+      {
+          "testcase_name": "1d",
+          "data": ["0 a b c d e a b c d f"],
+          "expected": [[1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1]]
+      },
+      {
+          "testcase_name":
+              "3d",
+          "data": [[["0 a b"], ["c d"]], [["e a"], ["b c d"]], [["f"]]],
+          "expected": [[[1, 2, 3], [4, 5, 0]], [[1, 2, 0], [3, 4, 5]],
+                       [[1, 0, 0], [0, 0, 0]]]
+      },
+  )
+  def test_layer_dimensionality_handling_with_split(self, data, expected):
+    vocab = ["a", "b", "c", "d"]
+    vectorization = get_layer_class()(
+        max_tokens=None,
+        standardize=None,
+        split=text_vectorization.SPLIT_ON_WHITESPACE,
+        pad_to_max_tokens=False)
+    vectorization.set_vocabulary(vocab)
+    output = vectorization(ragged_factory_ops.constant(data, inner_shape=(1,)))
+    self.assertAllEqual(expected, output)
+
 
 @keras_parameterized.run_all_keras_modes
 class TextVectorizationPreprocessingTest(
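A note on the split test's input construction: inner_shape=(1,) keeps the trailing unit axis that the layer squeezes before splitting on whitespace. Roughly:

import tensorflow as tf

rt = tf.ragged.constant([[["0 a b"], ["c d"]], [["f"]]], inner_shape=(1,))
print(rt.shape)  # (2, None, 1): ragged middle axis, unit innermost axis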
@@ -580,7 +633,7 @@ class TextVectorizationPreprocessingTest(
         split=text_vectorization.SPLIT_ON_WHITESPACE,
         output_mode=None)
     with self.assertRaisesRegex(RuntimeError,
-                                ".*tokenize strings, the first dimension.*"):
+                                ".*tokenize strings, the innermost dime.*"):
       _ = layer(input_data)
 
   def test_string_splitting_with_non_1d_raggedarray_fails(self):
@@ -591,7 +644,7 @@ class TextVectorizationPreprocessingTest(
         split=text_vectorization.SPLIT_ON_WHITESPACE,
         output_mode=None)
     with self.assertRaisesRegex(RuntimeError,
-                                ".*tokenize strings, the first dimension.*"):
+                                ".*tokenize strings, the innermost dime.*"):
       _ = layer(input_data)
 
   def test_standardization_with_invalid_standardize_arg(self):