Fix dimensionality handling issues in TextVectorization.

PiperOrigin-RevId: 313206360
Change-Id: I5929e83b26011c975561e525b90aef6949a185b2
Author: A. Unique TensorFlower (2020-05-26 09:31:32 -07:00), committed by TensorFlower Gardener
Parent: 00664cef68
Commit: a92ff929b8
3 changed files with 68 additions and 15 deletions


@@ -87,6 +87,8 @@ class TableHandler(object):
self.table.lookup, inputs)
indexed_data = ragged_functional_ops.map_flat_values(
self._replace_oov_buckets, inputs, indexed_data)
# table.lookup is not shape-preserving, so we need to set the shape here.
indexed_data._set_shape(inputs.shape) # pylint: disable=protected-access
# Composite tensors can pass tensor values through, which will cause
# errors if all operations in the TF graph do so. We can break this chain
# with an identity here.
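As context for the shape fix above: a minimal standalone sketch of the ragged lookup pattern this hunk touches, using only public TF 2.x APIs (not the layer's own code). The commit pins the static shape on the lookup result because, per its comment, table.lookup alone is not guaranteed to preserve it.

import tensorflow as tf

keys = tf.constant(["a", "b", "c"])
values = tf.constant([2, 3, 4], dtype=tf.int64)
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values), default_value=1)

tokens = tf.ragged.constant([["a", "b"], ["c", "oov"]])
# map_flat_values applies the lookup to the flat values and rebuilds the
# ragged structure around the result.
indexed = tf.ragged.map_flat_values(table.lookup, tokens)
print(indexed)  # <tf.RaggedTensor [[2, 3], [4, 1]]>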


@@ -490,11 +490,12 @@ class TextVectorization(CombinerPreprocessingLayer):
# in None for undefined shape axes. If using 'and !=', this causes the
# expression to evaluate to False instead of True if the shape is undefined;
# the expression needs to evaluate to True in that case.
if self._split is not None and not input_shape[1] == 1: # pylint: disable=g-comparison-negation
raise RuntimeError(
"When using TextVectorization to tokenize strings, the first "
"dimension of the input array must be 1, got shape "
"{}".format(input_shape))
if self._split is not None:
if input_shape.ndims > 1 and not input_shape[-1] == 1: # pylint: disable=g-comparison-negation
raise RuntimeError(
"When using TextVectorization to tokenize strings, the innermost "
"dimension of the input array must be 1, got shape "
"{}".format(input_shape))
super(TextVectorization, self).build(input_shape)
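A minimal standalone sketch of the relaxed build-time check (hypothetical helper name, not the layer's code): only inputs of rank > 1 must have an innermost dimension of 1 when splitting is requested; rank-1 string inputs now pass through.

import tensorflow as tf

def check_split_shape(input_shape, split="whitespace"):
  shape = tf.TensorShape(input_shape)
  if split is not None and shape.ndims is not None and shape.ndims > 1:
    # `not ... == 1` (rather than `!= 1`) mirrors the layer's intent that an
    # undefined innermost axis should also trigger the error.
    if not shape[-1] == 1:
      raise RuntimeError(
          "When using TextVectorization to tokenize strings, the innermost "
          "dimension of the input array must be 1, got shape {}".format(shape))

check_split_shape([None])      # rank 1: allowed
check_split_shape([None, 1])   # innermost dimension of 1: allowed
# check_split_shape([None, 2]) # raises RuntimeError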
@@ -536,7 +537,8 @@ class TextVectorization(CombinerPreprocessingLayer):
# If we are splitting, we validate that the innermost axis is of dimension 1
# and so can be squeezed out. We do this here instead of after splitting for
# performance reasons - it's more expensive to squeeze a ragged tensor.
inputs = array_ops.squeeze(inputs, axis=1)
if inputs.shape.ndims > 1:
inputs = array_ops.squeeze(inputs, axis=-1)
if self._split == SPLIT_ON_WHITESPACE:
# This treats multiple whitespaces as one whitespace, and strips leading
# and trailing whitespace.
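A short standalone sketch (not the layer's code) of why the innermost axis is squeezed before splitting: whitespace splitting expects one string per innermost element, so a trailing axis of size 1 is dropped first, while rank-1 inputs skip the squeeze entirely.

import tensorflow as tf

inputs = tf.constant([["0 a b"], ["c d"]])  # shape (2, 1): one string per row
squeezed = tf.squeeze(inputs, axis=-1)      # shape (2,)
tokens = tf.strings.split(squeezed)         # splits on runs of whitespace
print(tokens)  # <tf.RaggedTensor [[b'0', b'a', b'b'], [b'c', b'd']]>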
@@ -561,8 +563,6 @@ class TextVectorization(CombinerPreprocessingLayer):
def call(self, inputs):
if isinstance(inputs, (list, tuple, np.ndarray)):
inputs = ops.convert_to_tensor(inputs)
if inputs.shape.rank == 1:
inputs = array_ops.expand_dims(inputs, axis=-1)
self._called = True
inputs = self._preprocess(inputs)
@@ -570,9 +570,7 @@ class TextVectorization(CombinerPreprocessingLayer):
# If we're not doing any output processing, return right away.
if self._output_mode is None:
return inputs
indexed_data = self._index_lookup_layer(inputs)
if self._output_mode == INT:
# Once we have the dense tensor, we can return it if we weren't given a
# fixed output sequence length. If we were, though, we have to dynamically
@@ -585,7 +583,6 @@ class TextVectorization(CombinerPreprocessingLayer):
dense_data = indexed_data
if self._output_sequence_length is None:
dense_data.set_shape(tensor_shape.TensorShape((None, None)))
return dense_data
else:
sequence_len = K.shape(dense_data)[1]
@@ -596,8 +593,9 @@ class TextVectorization(CombinerPreprocessingLayer):
sequence_len < self._output_sequence_length,
true_fn=pad_fn,
false_fn=slice_fn)
output_tensor.set_shape(
tensor_shape.TensorShape((None, self._output_sequence_length)))
output_shape = output_tensor.shape.as_list()
output_shape[-1] = self._output_sequence_length
output_tensor.set_shape(tensor_shape.TensorShape(output_shape))
return output_tensor
# If we're not returning integers here, we rely on the vectorization layer
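The pad-or-slice step now pins only the innermost static dimension to output_sequence_length, leaving batch and any outer dimensions as they are. A hedged standalone sketch (hypothetical helper, not the layer's code):

import tensorflow as tf

def pad_or_trim_last_axis(dense, length):
  rank = dense.shape.rank                # static rank is assumed known
  seq_len = tf.shape(dense)[-1]          # dynamic length of the last axis
  pad_amount = tf.maximum(0, length - seq_len)
  paddings = [[0, 0]] * (rank - 1) + [[0, pad_amount]]
  padded = tf.pad(dense, paddings)
  trimmed = padded[..., :length]
  # Only the innermost static dimension is set to `length`; outer dimensions
  # keep their original (possibly unknown) sizes.
  out_shape = dense.shape.as_list()
  out_shape[-1] = length
  trimmed.set_shape(out_shape)
  return trimmed

x = tf.constant([[1, 2, 3], [4, 5, 0]])
print(pad_or_trim_last_axis(x, 5))  # zero-padded to shape (2, 5)
print(pad_or_trim_last_axis(x, 2))  # trimmed to shape (2, 2)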


@@ -355,6 +355,59 @@ class TextVectorizationLayerTest(keras_parameterized.TestCase,
if context.executing_eagerly():
self.assertAllClose(out.numpy(), [[2, 3], [4, 5]])
@parameterized.named_parameters(
{
"testcase_name": "1d",
"data": ["0", "a", "b", "c", "d", "e", "a", "b", "c", "d", "f"],
"expected": [1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1]
},
{
"testcase_name": "2d",
"data": [["0", "a", "b", "c", "d"], ["e", "a", "b", "c", "d"], ["f"]],
"expected": [[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 0, 0, 0, 0]]
},
{
"testcase_name":
"3d",
"data": [[["0", "a", "b"], ["c", "d"]], [["e", "a"], ["b", "c", "d"]],
[["f"]]],
"expected": [[[1, 2, 3], [4, 5, 0]], [[1, 2, 0], [3, 4, 5]],
[[1, 0, 0], [0, 0, 0]]]
},
)
def test_layer_dimensionality_handling(self, data, expected):
vocab = ["a", "b", "c", "d"]
vectorization = get_layer_class()(
max_tokens=None, standardize=None, split=None, pad_to_max_tokens=False)
vectorization.set_vocabulary(vocab)
output = vectorization(ragged_factory_ops.constant(data))
self.assertAllEqual(expected, output)
@parameterized.named_parameters(
{
"testcase_name": "1d",
"data": ["0 a b c d e a b c d f"],
"expected": [[1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1]]
},
{
"testcase_name":
"3d",
"data": [[["0 a b"], ["c d"]], [["e a"], ["b c d"]], [["f"]]],
"expected": [[[1, 2, 3], [4, 5, 0]], [[1, 2, 0], [3, 4, 5]],
[[1, 0, 0], [0, 0, 0]]]
},
)
def test_layer_dimensionality_handling_with_split(self, data, expected):
vocab = ["a", "b", "c", "d"]
vectorization = get_layer_class()(
max_tokens=None,
standardize=None,
split=text_vectorization.SPLIT_ON_WHITESPACE,
pad_to_max_tokens=False)
vectorization.set_vocabulary(vocab)
output = vectorization(ragged_factory_ops.constant(data, inner_shape=(1,)))
self.assertAllEqual(expected, output)
@keras_parameterized.run_all_keras_modes
class TextVectorizationPreprocessingTest(
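For readers of the tests above, a hedged end-user sketch of the same behavior through the public Keras API (module path as of TF 2.x at the time of this commit; the layer later moved out of the experimental namespace):

import tensorflow as tf

layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
    standardize=None, split=None)
layer.set_vocabulary(["a", "b", "c", "d"])
# With split=None and the default int output mode, token ids are looked up
# elementwise; OOV strings map to 1, matching the expected values above.
print(layer(tf.constant(["a", "b", "zzz"])))           # 1D input -> [2, 3, 1]
print(layer(tf.constant([["a", "b"], ["c", "zzz"]])))  # 2D input -> [[2, 3], [4, 1]]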
@@ -580,7 +633,7 @@ class TextVectorizationPreprocessingTest(
split=text_vectorization.SPLIT_ON_WHITESPACE,
output_mode=None)
with self.assertRaisesRegex(RuntimeError,
".*tokenize strings, the first dimension.*"):
".*tokenize strings, the innermost dime.*"):
_ = layer(input_data)
def test_string_splitting_with_non_1d_raggedarray_fails(self):
@@ -591,7 +644,7 @@ class TextVectorizationPreprocessingTest(
split=text_vectorization.SPLIT_ON_WHITESPACE,
output_mode=None)
with self.assertRaisesRegex(RuntimeError,
".*tokenize strings, the first dimension.*"):
".*tokenize strings, the innermost dime.*"):
_ = layer(input_data)
def test_standardization_with_invalid_standardize_arg(self):