Allow partial padding_values definitions with 'None' to indicate default values
For example, for a nested structure such as ([], [], {'hello': [], 'world': []}), padding_values can be given as (0, 1, None) to indicate that the nested structure in the third tuple element should be padded with default values. This makes it easier to specify padding for only some of the components when the defaults are fine for the rest. PiperOrigin-RevId: 286453825 Change-Id: I1aae75bd3c25b55175751504f79d116c3f9ee267
This commit is contained in:
parent
fc4d953221
commit
f1b0ec7596
@ -99,18 +99,26 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
batch_size=4, padded_shapes=[-1]))
|
||||
self.assertDatasetProduces(dataset, expected_output=[[[], [], [], []]])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPaddedBatchDatasetNonDefaultPadding(self):
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(
|
||||
padding_values=[(-1, '<end>', {'structure': ''}),
|
||||
(-1, '<end>', None)])))
|
||||
def testPaddedBatchDatasetNonDefaultPadding(self, padding_values):
|
||||
|
||||
def fill_tuple(x):
|
||||
filled = array_ops.fill([x], x)
|
||||
return (filled, string_ops.as_string(filled))
|
||||
return (filled, string_ops.as_string(filled), {
|
||||
'structure': string_ops.as_string(filled)
|
||||
})
|
||||
|
||||
random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32)
|
||||
dataset = (
|
||||
dataset_ops.Dataset.from_tensor_slices(random_seq_lens).map(fill_tuple)
|
||||
.padded_batch(
|
||||
4, padded_shapes=([-1], [-1]), padding_values=(-1, '<end>')))
|
||||
4, padded_shapes=([-1], [-1], {'structure': [-1]}),
|
||||
padding_values=padding_values))
|
||||
|
||||
get_next = self.getNext(dataset)
|
||||
for i in range(8):
|
||||
@ -118,6 +126,7 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
padded_len = np.max(result[0])
|
||||
self.assertEqual((4, padded_len), result[0].shape)
|
||||
self.assertEqual((4, padded_len), result[1].shape)
|
||||
self.assertEqual((4, padded_len), result[2]['structure'].shape)
|
||||
for j in range(4):
|
||||
seq_len = random_seq_lens[(i * 4) + j]
|
||||
self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len)
|
||||
@ -127,6 +136,10 @@ class PaddedBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
[compat.as_bytes(str(seq_len))] * seq_len)
|
||||
self.assertAllEqual(result[1][j, seq_len:],
|
||||
[b'<end>'] * (padded_len - seq_len))
|
||||
self.assertAllEqual(result[2]['structure'][j, :seq_len],
|
||||
[compat.as_bytes(str(seq_len))] * seq_len)
|
||||
self.assertAllEqual(result[2]['structure'][j, seq_len:],
|
||||
[b''] * (padded_len - seq_len))
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(get_next())
|
||||
|
||||
|
@ -1457,8 +1457,9 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
maximum size of that dimension in each batch.
|
||||
padding_values: (Optional.) A nested structure of scalar-shaped
|
||||
`tf.Tensor`, representing the padding values to use for the respective
|
||||
components. Defaults are `0` for numeric types and the empty string for
|
||||
string types.
|
||||
components. None represents that the nested structure should be padded
|
||||
with default values. Defaults are `0` for numeric types and the empty
|
||||
string for string types.
|
||||
drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
|
||||
whether the last batch should be dropped in the case it has fewer than
|
||||
`batch_size` elements; the default behavior is not to drop the smaller
|
||||
@ -3769,8 +3770,8 @@ def _padding_value_to_tensor(value, output_type):
|
||||
return value
|
||||
|
||||
|
||||
def _default_padding(input_dataset):
|
||||
"""Returns default padding tensors in a structure matching `input_dataset`."""
|
||||
def _padding_values_or_default(padding_values, input_dataset):
|
||||
"""Returns padding values with None elements replaced with default values."""
|
||||
def make_zero(t):
|
||||
if t.base_dtype == dtypes.string:
|
||||
return ""
|
||||
@ -3782,9 +3783,13 @@ def _default_padding(input_dataset):
|
||||
raise TypeError(error_msg)
|
||||
else:
|
||||
return np.zeros_like(t.as_numpy_dtype())
|
||||
def value_or_default(value, default):
|
||||
return default if value is None else value
|
||||
|
||||
return nest.map_structure(
|
||||
make_zero, get_legacy_output_types(input_dataset))
|
||||
default_padding = nest.map_structure(make_zero,
|
||||
get_legacy_output_types(input_dataset))
|
||||
return nest.map_structure_up_to(padding_values, value_or_default,
|
||||
padding_values, default_padding)
|
||||
|
||||
|
||||
class PaddedBatchDataset(UnaryDataset):
|
||||
@ -3801,9 +3806,7 @@ class PaddedBatchDataset(UnaryDataset):
|
||||
self._input_dataset = input_dataset
|
||||
self._batch_size = ops.convert_to_tensor(
|
||||
batch_size, dtype=dtypes.int64, name="batch_size")
|
||||
padding_values = (
|
||||
padding_values
|
||||
if padding_values is not None else _default_padding(input_dataset))
|
||||
padding_values = _padding_values_or_default(padding_values, input_dataset)
|
||||
|
||||
input_shapes = get_legacy_output_shapes(input_dataset)
|
||||
flat_padded_shapes = nest.flatten_up_to(input_shapes, padded_shapes)
|
||||
|
Loading…
x
Reference in New Issue
Block a user