Fix dimensionality handling issues in TextVectorization.
PiperOrigin-RevId: 313206360
Change-Id: I5929e83b26011c975561e525b90aef6949a185b2
parent 00664cef68
commit a92ff929b8

@@ -87,6 +87,8 @@ class TableHandler(object):
         self.table.lookup, inputs)
     indexed_data = ragged_functional_ops.map_flat_values(
         self._replace_oov_buckets, inputs, indexed_data)
+    # table.lookup is not shape-preserving, so we need to set the shape here.
+    indexed_data._set_shape(inputs.shape)  # pylint: disable=protected-access
     # Composite tensors can pass tensor values through, which will cause
     # errors if all operations in the TF graph do so. We can break this chain
     # with an identity here.

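The two lines added here re-attach the input's static shape after the ragged lookup. The same idea can be sketched with the public API; this is a minimal illustration, assuming `tf.lookup.StaticHashTable` and a stand-in OOV value of 1 (the layer's `TableHandler` is internal):

```python
import tensorflow as tf

table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(["a", "b", "c"]),
        values=tf.constant([2, 3, 4], dtype=tf.int64)),
    default_value=1)  # 1 stands in for the OOV bucket

@tf.function(input_signature=[tf.TensorSpec([None, 1], tf.string)])
def lookup_keeping_shape(inputs):
  indexed = table.lookup(inputs)
  # Depending on the TF version, shape inference for the lookup op may drop
  # the input's static shape; merging it back is a no-op when it did not.
  indexed.set_shape(inputs.shape)
  return indexed

print(lookup_keeping_shape(tf.constant([["a"], ["z"]])))  # [[2], [1]]
```
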
@@ -490,11 +490,12 @@ class TextVectorization(CombinerPreprocessingLayer):
     # in None for undefined shape axes. If using 'and !=', this causes the
     # expression to evaluate to False instead of True if the shape is undefined;
     # the expression needs to evaluate to True in that case.
-    if self._split is not None and not input_shape[1] == 1:  # pylint: disable=g-comparison-negation
-      raise RuntimeError(
-          "When using TextVectorization to tokenize strings, the first "
-          "dimension of the input array must be 1, got shape "
-          "{}".format(input_shape))
+    if self._split is not None:
+      if input_shape.ndims > 1 and not input_shape[-1] == 1:  # pylint: disable=g-comparison-negation
+        raise RuntimeError(
+            "When using TextVectorization to tokenize strings, the innermost "
+            "dimension of the input array must be 1, got shape "
+            "{}".format(input_shape))
 
     super(TextVectorization, self).build(input_shape)

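In effect, the relaxed check rejects rank > 1 inputs unless the innermost axis is statically 1, and still trips on an undefined innermost axis, as the comment above requires. A standalone restatement, with `check_split_input_shape` as a hypothetical helper:

```python
import tensorflow as tf

def check_split_input_shape(input_shape):
  # `not ... == 1` (rather than `!= 1`) makes an undefined innermost axis
  # fail too, since it cannot be proven to be 1.
  shape = tf.TensorShape(input_shape)
  if shape.ndims is not None and shape.ndims > 1 and not shape[-1] == 1:
    raise RuntimeError(
        "When using TextVectorization to tokenize strings, the innermost "
        "dimension of the input array must be 1, got shape {}".format(shape))

check_split_input_shape([None, 1])  # fine: innermost axis is 1
check_split_input_shape([None])     # fine: rank-1 inputs are now accepted
# check_split_input_shape([2, 3])   # raises RuntimeError
```
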
@@ -536,7 +537,8 @@ class TextVectorization(CombinerPreprocessingLayer):
       # If we are splitting, we validate that the 1st axis is of dimension 1 and
       # so can be squeezed out. We do this here instead of after splitting for
       # performance reasons - it's more expensive to squeeze a ragged tensor.
-      inputs = array_ops.squeeze(inputs, axis=1)
+      if inputs.shape.ndims > 1:
+        inputs = array_ops.squeeze(inputs, axis=-1)
       if self._split == SPLIT_ON_WHITESPACE:
         # This treats multiple whitespaces as one whitespace, and strips leading
         # and trailing whitespace.

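The squeeze now targets the innermost axis and is skipped for rank-1 inputs. Roughly, with public ops standing in for the internal `array_ops`/ragged string ops:

```python
import tensorflow as tf

docs = tf.constant([["brown  fox"], ["lazy dog"]])  # shape (2, 1)
# Drop the length-1 innermost axis before splitting; squeezing a dense
# tensor is cheaper than squeezing the ragged result of the split.
squeezed = tf.squeeze(docs, axis=-1)                # shape (2,)
tokens = tf.strings.split(squeezed)  # runs of whitespace count as one
print(tokens)  # <tf.RaggedTensor [[b'brown', b'fox'], [b'lazy', b'dog']]>
```
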
@@ -561,8 +563,6 @@ class TextVectorization(CombinerPreprocessingLayer):
   def call(self, inputs):
     if isinstance(inputs, (list, tuple, np.ndarray)):
       inputs = ops.convert_to_tensor(inputs)
-    if inputs.shape.rank == 1:
-      inputs = array_ops.expand_dims(inputs, axis=-1)
 
     self._called = True
     inputs = self._preprocess(inputs)

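With the automatic `expand_dims` gone, a rank-1 batch keeps its rank through the layer, matching the new 1d test case below. A hedged usage sketch, assuming the TF 2.2-era `tf.keras.layers.experimental.preprocessing` export and the layer's reserved indices (0 for padding, 1 for OOV):

```python
import tensorflow as tf

layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
    max_tokens=None, standardize=None, split=None)
layer.set_vocabulary(["a", "b", "c", "d"])
# Rank-1 in, rank-1 out: no hidden reshape to (batch, 1) anymore.
print(layer(tf.constant(["a", "d", "z"])))  # [2, 5, 1]
```
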
@@ -570,9 +570,7 @@ class TextVectorization(CombinerPreprocessingLayer):
     # If we're not doing any output processing, return right away.
     if self._output_mode is None:
       return inputs
 
     indexed_data = self._index_lookup_layer(inputs)
-    # table.lookup is not shape-preserving, so we need to set the shape here.
-    indexed_data._set_shape(inputs.shape)  # pylint: disable=protected-access
 
     if self._output_mode == INT:
       # Once we have the dense tensor, we can return it if we weren't given a
       # fixed output sequence length. If we were, though, we have to dynamically

@@ -585,7 +583,6 @@ class TextVectorization(CombinerPreprocessingLayer):
         dense_data = indexed_data
 
       if self._output_sequence_length is None:
-        dense_data.set_shape(tensor_shape.TensorShape((None, None)))
         return dense_data
       else:
         sequence_len = K.shape(dense_data)[1]

@@ -596,8 +593,9 @@ class TextVectorization(CombinerPreprocessingLayer):
             sequence_len < self._output_sequence_length,
             true_fn=pad_fn,
             false_fn=slice_fn)
-        output_tensor.set_shape(
-            tensor_shape.TensorShape((None, self._output_sequence_length)))
+        output_shape = output_tensor.shape.as_list()
+        output_shape[-1] = self._output_sequence_length
+        output_tensor.set_shape(tensor_shape.TensorShape(output_shape))
         return output_tensor
 
     # If we're not returning integers here, we rely on the vectorization layer

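Rather than pinning the static shape to `(None, output_sequence_length)`, which assumed a rank-2 output, the new code copies the tensor's own shape and overwrites only the last axis. A sketch in isolation, with `pin_last_axis` as a hypothetical helper and 4 as the configured length (the input is assumed to be already padded or trimmed):

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([None, None], tf.int64)])
def pin_last_axis(dense_data):
  output_shape = dense_data.shape.as_list()  # e.g. [None, None]
  output_shape[-1] = 4                       # the output_sequence_length
  dense_data.set_shape(output_shape)         # now (None, 4): rank preserved
  return dense_data

print(pin_last_axis(tf.zeros([2, 4], tf.int64)).shape)  # (2, 4)
```
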
@@ -355,6 +355,59 @@ class TextVectorizationLayerTest(keras_parameterized.TestCase,
     if context.executing_eagerly():
       self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
 
+  @parameterized.named_parameters(
+      {
+          "testcase_name": "1d",
+          "data": ["0", "a", "b", "c", "d", "e", "a", "b", "c", "d", "f"],
+          "expected": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1]
+      },
+      {
+          "testcase_name": "2d",
+          "data": [["0", "a", "b", "c", "d"], ["e", "a", "b", "c", "d"], ["f"]],
+          "expected": [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 0, 0, 0, 0]]
+      },
+      {
+          "testcase_name":
+              "3d",
+          "data": [[["0", "a", "b"], ["c", "d"]], [["e", "a"], ["b", "c", "d"]],
+                   [["f"]]],
+          "expected": [[[1, 2, 3], [4, 5, 0]], [[1, 2, 0], [3, 4, 5]],
+                       [[1, 0, 0], [0, 0, 0]]]
+      },
+  )
+  def test_layer_dimensionality_handling(self, data, expected):
+    vocab = ["a", "b", "c", "d"]
+    vectorization = get_layer_class()(
+        max_tokens=None, standardize=None, split=None, pad_to_max_tokens=False)
+    vectorization.set_vocabulary(vocab)
+    output = vectorization(ragged_factory_ops.constant(data))
+    self.assertAllEqual(expected, output)
+
+  @parameterized.named_parameters(
+      {
+          "testcase_name": "1d",
+          "data": ["0 a b c d e a b c d f"],
+          "expected": [[1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1]]
+      },
+      {
+          "testcase_name":
+              "3d",
+          "data": [[["0 a b"], ["c d"]], [["e a"], ["b c d"]], [["f"]]],
+          "expected": [[[1, 2, 3], [4, 5, 0]], [[1, 2, 0], [3, 4, 5]],
+                       [[1, 0, 0], [0, 0, 0]]]
+      },
+  )
+  def test_layer_dimensionality_handling_with_split(self, data, expected):
+    vocab = ["a", "b", "c", "d"]
+    vectorization = get_layer_class()(
+        max_tokens=None,
+        standardize=None,
+        split=text_vectorization.SPLIT_ON_WHITESPACE,
+        pad_to_max_tokens=False)
+    vectorization.set_vocabulary(vocab)
+    output = vectorization(ragged_factory_ops.constant(data, inner_shape=(1,)))
+    self.assertAllEqual(expected, output)
 
 @keras_parameterized.run_all_keras_modes
 class TextVectorizationPreprocessingTest(

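The expected values in these tests follow the layer's reserved indices at the time: 0 for padding, 1 for OOV, so the vocabulary ["a", "b", "c", "d"] maps to 2 through 5, and short ragged rows are zero-padded when densified. A small sketch of the same convention, again assuming the experimental preprocessing export:

```python
import tensorflow as tf

layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
    max_tokens=None, standardize=None, split=None)
layer.set_vocabulary(["a", "b", "c", "d"])
out = layer(tf.ragged.constant([["a", "z"], ["d"]]))
print(out)  # [[2, 1], [5, 0]] -- "z" is OOV, the short row is zero-padded
```
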
@@ -580,7 +633,7 @@ class TextVectorizationPreprocessingTest(
         split=text_vectorization.SPLIT_ON_WHITESPACE,
         output_mode=None)
     with self.assertRaisesRegex(RuntimeError,
-                                ".*tokenize strings, the first dimension.*"):
+                                ".*tokenize strings, the innermost dime.*"):
       _ = layer(input_data)
 
   def test_string_splitting_with_non_1d_raggedarray_fails(self):

@@ -591,7 +644,7 @@ class TextVectorizationPreprocessingTest(
         split=text_vectorization.SPLIT_ON_WHITESPACE,
         output_mode=None)
     with self.assertRaisesRegex(RuntimeError,
-                                ".*tokenize strings, the first dimension.*"):
+                                ".*tokenize strings, the innermost dime.*"):
      _ = layer(input_data)
 
   def test_standardization_with_invalid_standardize_arg(self):