Move the tensor shape recovery logic of DistributedIterator's get_next_as_optional earlier, so it covers more use cases.

PiperOrigin-RevId: 314277550
Change-Id: I2c8f3afcc791b02310b5242251f335d9da7cd5bf
Ruoxin Sang 2020-06-02 00:03:54 -07:00 committed by TensorFlower Gardener
parent 0482d18b52
commit bc49458b14
2 changed files with 71 additions and 27 deletions
tensorflow/python/distribute


@@ -349,33 +349,6 @@ class DistributedIteratorBase(distribute_types.Iterator):
        results.append(result)
      replicas = results

    # Some dimensions in `replicas` will become unknown after we conditionally
    # return the real tensors or the dummy tensors. We fix the input shapes by
    # using the shapes from `out_of_range_replicas` because it is calling
    # get_next() inside.
    flattened_replicas = nest.flatten(replicas)
    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
      for target, source in zip(
          nest.flatten(flattened_replicas[i], expand_composites=True),
          nest.flatten(replica_data, expand_composites=True)):
        target.set_shape(source.get_shape())
      # `SparseTensor` shape is not determined by the shape of its component
      # tensors. Rather, its shape depends on a tensor's values.
      if sparse_tensor.is_sparse(replica_data) and replica_data.get_shape():
        dense_shape = replica_data.get_shape()
        with ops.device(flattened_replicas[i].op.device):
          # For partially defined shapes, fill in missing values from tensor.
          if not dense_shape.is_fully_defined():
            dense_shape = array_ops.stack([
                flattened_replicas[i].dense_shape[j] if dim is None else dim
                for j, dim in enumerate(dense_shape.as_list())
            ])
          flattened_replicas[i] = sparse_tensor.SparseTensor(
              indices=flattened_replicas[i].indices,
              values=flattened_replicas[i].values,
              dense_shape=dense_shape)
    replicas = nest.pack_sequence_as(replicas, flattened_replicas)

    return values.regroup(replicas)
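The deleted block above (and its replacement later in this commit) exists because `tf.cond` only keeps the shape dimensions its two branches agree on, so conditionally returning real or dummy tensors erases static shapes. A minimal sketch of that effect and of the `set_shape` recovery, with illustrative hard-coded shapes standing in for the ones the real code takes from get_next():

import tensorflow as tf

@tf.function
def get_batch(has_data):
  real = tf.ones([8, 10])    # real batch, static shape [8, 10]
  dummy = tf.zeros([0, 10])  # out-of-range placeholder, static shape [0, 10]
  batch = tf.cond(has_data, lambda: real, lambda: dummy)
  # The branches disagree on dimension 0, so batch.shape is (None, 10) here.
  # set_shape only rewrites static metadata; the real code sets the shape
  # known from the out-of-range get_next() call rather than a constant.
  batch.set_shape([8, 10])
  return batch

print(get_batch(tf.constant(True)).shape)  # (8, 10)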
@@ -1048,6 +1021,34 @@ def _dummy_tensor_fn(value_structure):
  return nest.map_structure(create_dummy_tensor, value_structure)


def _recover_shape_fn(data, value_structure):
  """Recovers the shapes of `data` to match those of `value_structure`."""
  flattened_data = nest.flatten(data)
  for i, spec in enumerate(nest.flatten(value_structure)):
    for target, source in zip(
        nest.flatten(flattened_data[i], expand_composites=True),
        nest.flatten(spec, expand_composites=True)):
      target.set_shape(source.shape)
    # `SparseTensor` shape is not determined by the shape of its component
    # tensors. Rather, its shape depends on a tensor's values.
    if isinstance(spec, sparse_tensor.SparseTensorSpec) and spec.shape:
      dense_shape = spec.shape
      with ops.device(flattened_data[i].op.device):
        # For partially defined shapes, fill in missing values from tensor.
        if not dense_shape.is_fully_defined():
          dense_shape = array_ops.stack([
              flattened_data[i].dense_shape[j] if dim is None else dim
              for j, dim in enumerate(dense_shape.as_list())
          ])
        flattened_data[i] = sparse_tensor.SparseTensor(
            indices=flattened_data[i].indices,
            values=flattened_data[i].values,
            dense_shape=dense_shape)
  data = nest.pack_sequence_as(data, flattened_data)
  return data

class _SingleWorkerDatasetIteratorBase(object):
  """Iterator for a single `tf.data.Dataset`."""
@@ -1132,6 +1133,13 @@ class _SingleWorkerDatasetIteratorBase(object):
            lambda: _dummy_tensor_fn(data.value_structure),
            strict=True,
        )
        # Some dimensions in `replicas` will become unknown after we
        # conditionally return the real tensors or the dummy tensors. Recover
        # the shapes from `data.value_structure`. We only need to do this in
        # non-eager mode because we always know the runtime shape of the
        # tensors in eager mode.
        if not context.executing_eagerly():
          real_data = _recover_shape_fn(real_data, data.value_structure)
        result.append(real_data)
      # pylint: enable=cell-var-from-loop
      # pylint: enable=unnecessary-lambda
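A small illustration (not from the diff) of the eager/graph distinction the new comment relies on: in eager mode tensors carry their concrete runtime shapes, so there is nothing to recover, while the same `tf.cond` traced inside a `tf.function` produces a partially unknown static shape:

import tensorflow as tf

def merged(flag):
  return tf.cond(flag, lambda: tf.ones([8, 10]), lambda: tf.zeros([0, 10]))

print(merged(tf.constant(True)).shape)  # eager: concrete (8, 10)

@tf.function
def traced(flag):
  out = merged(flag)
  print(out.shape)  # printed at trace time: (None, 10)
  return out

traced(tf.constant(True))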


@@ -48,6 +48,7 @@ from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
@@ -612,6 +613,41 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
      else:
        self.assertAllEqual(first_epoch, second_epoch)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          distribution=[
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_one_cpu,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
              strategy_combinations.central_storage_strategy_with_two_gpus,
          ]))
  def testGetNextOptionalShape(self, distribution):
    batch_size = 8
    dataset = dataset_ops.DatasetV2.from_tensor_slices({
        "feature": array_ops.ones([batch_size, 10]),
        "label": array_ops.ones([batch_size]),
    })
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dist_dataset = distribution.experimental_distribute_dataset(dataset)
    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync

    @def_function.function
    def train_fn():
      for data in dist_dataset:
        data = nest.map_structure(distribution.experimental_local_results,
                                  data)
        feature = data["feature"]
        label = data["label"]
        # Assert that the shapes are still static across all replicas.
        for replica_id in range(distribution.num_replicas_in_sync):
          self.assertEqual([per_replica_batch_size, 10],
                           feature[replica_id].shape)
          self.assertEqual([per_replica_batch_size], label[replica_id].shape)

    train_fn()


class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
                                        parameterized.TestCase):
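One hedged aside on the test's use of `drop_remainder=True`: it is what makes the batch dimension statically known in the first place, so the assertions have a fixed per-replica shape to check against. Without it the element spec itself has an unknown batch dimension, and no recovery could make it static:

import tensorflow as tf

ds = tf.data.Dataset.range(8).batch(8)
print(ds.element_spec.shape)  # (None,): the last batch may be partial

ds = tf.data.Dataset.range(8).batch(8, drop_remainder=True)
print(ds.element_spec.shape)  # (8,): batch dimension is statically known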