Move the tensor-shape recovery logic of DistributedIterator's get_next_as_optional earlier, so it covers more use cases.
PiperOrigin-RevId: 314277550 Change-Id: I2c8f3afcc791b02310b5242251f335d9da7cd5bf
This commit is contained in:
parent 0482d18b52
commit bc49458b14

tensorflow/python/distribute
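
For context, a minimal standalone sketch (not part of the commit; the function and variable names are illustrative) of the problem the moved logic addresses: when a tf.cond chooses between a real batch and a dummy tensor whose unknown dimensions were filled with 0, the merged result's static shape is relaxed, and set_shape restores it from the known element spec.

import tensorflow as tf

@tf.function
def next_or_dummy(has_data):
  spec = tf.TensorSpec([8, 10], tf.float32)  # known element spec
  real = tf.ones(spec.shape)   # stands in for the real next element
  dummy = tf.zeros([0, 10])    # dummy with 0 for the batch dim
  # The branches disagree on the batch dim (8 vs. 0), so in graph mode the
  # merged tensor's static shape is relaxed to [None, 10].
  value = tf.cond(has_data, lambda: real, lambda: dummy)
  # Restore the static shape from the spec, as the shape-recovery logic does.
  # This is safe only because callers never consume the dummy value.
  value.set_shape(spec.shape)
  return value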
tensorflow/python/distribute/input_lib.py

@@ -349,33 +349,6 @@ class DistributedIteratorBase(distribute_types.Iterator):
       results.append(result)
     replicas = results
 
-    # Some dimensions in `replicas` will become unknown after we conditionally
-    # return the real tensors or the dummy tensors. We fix the input shapes by
-    # using the shapes from `out_of_range_replicas` because it is calling
-    # get_next() inside.
-    flattened_replicas = nest.flatten(replicas)
-    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
-      for target, source in zip(
-          nest.flatten(flattened_replicas[i], expand_composites=True),
-          nest.flatten(replica_data, expand_composites=True)):
-        target.set_shape(source.get_shape())
-      # `SparseTensor` shape is not determined by the shape of its component
-      # tensors. Rather, its shape depends on a tensor's values.
-      if sparse_tensor.is_sparse(replica_data) and replica_data.get_shape():
-        dense_shape = replica_data.get_shape()
-        with ops.device(flattened_replicas[i].op.device):
-          # For partially defined shapes, fill in missing values from tensor.
-          if not dense_shape.is_fully_defined():
-            dense_shape = array_ops.stack([
-                flattened_replicas[i].dense_shape[j] if dim is None else dim
-                for j, dim in enumerate(dense_shape.as_list())
-            ])
-          flattened_replicas[i] = sparse_tensor.SparseTensor(
-              indices=flattened_replicas[i].indices,
-              values=flattened_replicas[i].values,
-              dense_shape=dense_shape)
-    replicas = nest.pack_sequence_as(replicas, flattened_replicas)
-
     return values.regroup(replicas)
 
 
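A side note on the SparseTensor branch of the block removed above: a SparseTensor's shape lives in the values of its dense_shape component rather than in the static shapes of its component tensors, which is why set_shape alone is not enough. A minimal eager sketch with illustrative values (not from the commit):

import tensorflow as tf

st = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1.0, 2.0],
                     dense_shape=[2, 3])
partial = tf.TensorShape([2, None])  # a partially defined spec shape
# Fill unknown dims from the runtime dense_shape, mirroring the stack above.
dense_shape = tf.stack([
    st.dense_shape[j] if dim is None else dim
    for j, dim in enumerate(partial.as_list())
])
st = tf.SparseTensor(st.indices, st.values, dense_shape)
print(st.dense_shape)  # [2 3]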
@@ -1048,6 +1021,34 @@ def _dummy_tensor_fn(value_structure):
   return nest.map_structure(create_dummy_tensor, value_structure)
 
 
+def _recover_shape_fn(data, value_structure):
+  """Recover the shape of `data` to match the shape of `value_structure`."""
+
+  flattened_data = nest.flatten(data)
+  for i, spec in enumerate(nest.flatten(value_structure)):
+    for target, source in zip(
+        nest.flatten(flattened_data[i], expand_composites=True),
+        nest.flatten(spec, expand_composites=True)):
+      target.set_shape(source.shape)
+    # `SparseTensor` shape is not determined by the shape of its component
+    # tensors. Rather, its shape depends on a tensor's values.
+    if isinstance(spec, sparse_tensor.SparseTensorSpec) and spec.shape:
+      dense_shape = spec.shape
+      with ops.device(flattened_data[i].op.device):
+        # For partially defined shapes, fill in missing values from tensor.
+        if not dense_shape.is_fully_defined():
+          dense_shape = array_ops.stack([
+              flattened_data[i].dense_shape[j] if dim is None else dim
+              for j, dim in enumerate(dense_shape.as_list())
+          ])
+        flattened_data[i] = sparse_tensor.SparseTensor(
+            indices=flattened_data[i].indices,
+            values=flattened_data[i].values,
+            dense_shape=dense_shape)
+  data = nest.pack_sequence_as(data, flattened_data)
+  return data
+
+
 class _SingleWorkerDatasetIteratorBase(object):
   """Iterator for a single `tf.data.Dataset`."""
 
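On the zip inside the new helper: flattening with expand_composites=True aligns component tensors with component specs, so composites like SparseTensor are handled uniformly. A small sketch under that assumption (illustrative values, not from the commit):

import tensorflow as tf

st = tf.SparseTensor(indices=[[0, 0]], values=[1.0], dense_shape=[2, 3])
spec = tf.SparseTensorSpec(shape=[2, 3], dtype=tf.float32)
# Both flatten to three aligned components: indices, values, dense_shape.
print(len(tf.nest.flatten(st, expand_composites=True)))    # 3
print(len(tf.nest.flatten(spec, expand_composites=True)))  # 3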
@@ -1132,6 +1133,13 @@ class _SingleWorkerDatasetIteratorBase(object):
              lambda: _dummy_tensor_fn(data.value_structure),
              strict=True,
          )
+         # Some dimensions in `replicas` will become unknown after we
+         # conditionally return the real tensors or the dummy tensors. Recover
+         # the shapes from `data.value_structure`. We only need to do this in
+         # non-eager mode because we always know the runtime shape of the
+         # tensors in eager mode.
+         if not context.executing_eagerly():
+           real_data = _recover_shape_fn(real_data, data.value_structure)
          result.append(real_data)
        # pylint: enable=cell-var-from-loop
        # pylint: enable=unnecessary-lambda
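Why the context.executing_eagerly() guard suffices, in a small sketch (function name illustrative): in eager mode only the taken cond branch runs, so the result's shape is always fully known, while graph-mode tracing merges both branch shapes.

import tensorflow as tf

def pick(flag):
  # Branches disagree on the batch dim, mimicking real vs. dummy tensors.
  return tf.cond(flag, lambda: tf.ones([8, 10]), lambda: tf.zeros([0, 10]))

print(pick(tf.constant(True)).shape)  # (8, 10): eager runs one branch
cf = tf.function(pick).get_concrete_function(tf.TensorSpec([], tf.bool))
print(cf.output_shapes)  # (None, 10): static shape relaxed while tracing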
tensorflow/python/distribute/input_lib_test.py

@@ -48,6 +48,7 @@ from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import sparse_ops
@@ -612,6 +613,41 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
     else:
       self.assertAllEqual(first_epoch, second_epoch)
 
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.one_device_strategy,
+              strategy_combinations.mirrored_strategy_with_one_cpu,
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.central_storage_strategy_with_two_gpus,
+          ]))
+  def testGetNextOptionalShape(self, distribution):
+    batch_size = 8
+    dataset = dataset_ops.DatasetV2.from_tensor_slices({
+        "feature": array_ops.ones([batch_size, 10]),
+        "label": array_ops.ones([batch_size]),
+    })
+    dataset = dataset.batch(batch_size, drop_remainder=True)
+    dist_dataset = distribution.experimental_distribute_dataset(dataset)
+    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync
+
+    @def_function.function
+    def train_fn():
+      for data in dist_dataset:
+        data = nest.map_structure(distribution.experimental_local_results, data)
+        feature = data["feature"]
+        label = data["label"]
+
+        # Assert the shapes are still static from all replicas.
+        for replica_id in range(distribution.num_replicas_in_sync):
+          self.assertEqual([per_replica_batch_size, 10],
+                           feature[replica_id].shape)
+          self.assertEqual([per_replica_batch_size], label[replica_id].shape)
+
+    train_fn()
+
+
 class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
                                         parameterized.TestCase):
 
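One detail the new test leans on, shown in a short sketch: drop_remainder=True is what makes the batch dimension static in the dataset's element spec; without it there would be no fully defined shape to recover.

import tensorflow as tf

ds = tf.data.Dataset.range(8).batch(8)
print(ds.element_spec.shape)  # (None,): last batch may be partial
ds = tf.data.Dataset.range(8).batch(8, drop_remainder=True)
print(ds.element_spec.shape)  # (8,): batch dim is statically known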