Merge pull request #44632 from kushanam:keras_distribute_lib
PiperOrigin-RevId: 342159822
Change-Id: I36d2f359ede1ed853a279fe705f79be8860ad5b8
commit d2a8bec3cd
tensorflow/python/distribute
@@ -28,6 +28,7 @@ from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import cardinality
 from tensorflow.python.data.experimental.ops import distribute
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.ops import multi_device_iterator_ops
 from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.distribute import device_util
@@ -759,11 +760,11 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
 
   __slots__ = [
       "_input_workers", "_element_spec", "_strategy",
-      "_enable_get_next_as_optional"
+      "_enable_get_next_as_optional", "_options"
   ]
 
   def __init__(self, input_workers, element_spec, strategy,
-               enable_get_next_as_optional):
+               enable_get_next_as_optional, options):
     # We don't want to allow deserialization of this class because we don't
     # serialize the strategy object. Currently the only places where
     # _deserialize is called is when we save/restore using SavedModels.
@@ -775,6 +776,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
     self._element_spec = element_spec
     self._strategy = strategy
     self._enable_get_next_as_optional = enable_get_next_as_optional
+    self._options = options
 
   @property
   def value_type(self):
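Note (not part of the diff): the remaining DistributedIteratorSpec hunks below thread this same `options` value through every path the spec exposes (_serialize, _component_specs, _from_components, from_value, _with_tensor_ranks_only, and the compatible-type path at line 816), which is why one new constructor argument touches so many methods. A toy sketch of that pattern, with purely illustrative names and no TensorFlow dependencies:

class _ToySpec(object):
  """Stand-in for a TypeSpec-like class gaining a new constructor field."""

  __slots__ = ["_payload", "_options"]

  def __init__(self, payload, options):
    self._payload = payload
    self._options = options  # the new field is stored once...

  def _serialize(self):
    # ...folded into the comparison key...
    return (self._payload, id(self._options))

  @staticmethod
  def from_value(value):
    # ...and forwarded by every path that rebuilds the spec.
    return _ToySpec(value.payload, value.options)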
@@ -783,8 +785,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
   def _serialize(self):
     # We cannot serialize the strategy object so we convert it to an id that we
     # can use for comparison.
-    return (self._input_workers.serialize(),
-            self._element_spec, id(self._strategy))
+    return (self._input_workers.serialize(), self._element_spec,
+            id(self._strategy), id(self._options))
 
   def _deserialize(self):
     raise ValueError("Deserialization is currently unsupported for "
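The serialized tuple feeds TypeSpec equality and hashing, so the options object, like the strategy, is compared by identity rather than by value. A small illustration of the consequence (hypothetical tuples standing in for the real serialized form):

from tensorflow.python.distribute import distribute_lib

shared = distribute_lib.InputOptions()
copy = distribute_lib.InputOptions()  # same field values, different object

spec_key_a = ("worker-pairs", "element-spec", 1, id(shared))
spec_key_b = ("worker-pairs", "element-spec", 1, id(shared))
spec_key_c = ("worker-pairs", "element-spec", 1, id(copy))

assert spec_key_a == spec_key_b  # reusing the same InputOptions -> equal keys
assert spec_key_a != spec_key_c  # an equal-valued copy still differs by id()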
@@ -816,7 +818,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
         other._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)
 
   @property
   def _component_specs(self):
@@ -826,9 +829,9 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
     for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
       element_spec = nest.map_structure(
           functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
-      specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
-                                                    compute_devices,
-                                                    element_spec))
+      specs.append(
+          _SingleWorkerDatasetIteratorSpec(input_device, compute_devices,
+                                           element_spec, self._options))
     return specs
 
   def _to_components(self, value):
@@ -841,14 +844,16 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
         components=components,
         element_spec=self._element_spec,
         strategy=self._strategy,
-        enable_get_next_as_optional=self._enable_get_next_as_optional)
+        enable_get_next_as_optional=self._enable_get_next_as_optional,
+        options=self._options)
 
   @staticmethod
   def from_value(value):
     # pylint: disable=protected-access
     return DistributedIteratorSpec(value._input_workers, value._element_spec,
                                    value._strategy,
-                                   value._enable_get_next_as_optional)
+                                   value._enable_get_next_as_optional,
+                                   value._options)
 
   def _with_tensor_ranks_only(self):
     element_spec = nest.map_structure(
@@ -856,7 +861,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
         self._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)
 
 
 class DistributedIterator(DistributedIteratorBase,
@@ -869,7 +875,8 @@ class DistributedIterator(DistributedIteratorBase,
                strategy=None,
                components=None,
                element_spec=None,
-               enable_get_next_as_optional=False):
+               enable_get_next_as_optional=False,
+               options=None):
     if input_workers is None:
       raise ValueError("`input_workers` should be "
                        "provided.")
@@ -877,6 +884,7 @@ class DistributedIterator(DistributedIteratorBase,
     error_message = ("Either `input_workers` or "
                      "both `components` and `element_spec` need to be "
                      "provided.")
+    self._options = options
 
     if iterators is None:
       if (components is None or element_spec is None):
@@ -916,7 +924,8 @@ class DistributedIterator(DistributedIteratorBase,
     # TODO(b/163362689): remove the comment after the bug if fixed.
     return DistributedIteratorSpec(self._input_workers, self._element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)
 
 
 class _IterableInput(DistributedDatasetInterface):
@@ -1290,7 +1299,8 @@ class DistributedDatasetsFromFunction(_IterableInput):
           input_workers=self._input_workers,
           iterators=iterators,
           strategy=self._strategy,
-          enable_get_next_as_optional=self._enable_get_next_as_optional)
+          enable_get_next_as_optional=self._enable_get_next_as_optional,
+          options=self._options)
       iterator._element_spec = self._element_spec  # pylint: disable=protected-access
 
       # When async eager is enabled, sometimes the iterator may not finish
@@ -1585,9 +1595,7 @@ class _SingleWorkerDatasetIteratorBase(object):
     """Get next element for the given device."""
     del name
     with ops.device(self._worker):
-      if isinstance(self._iterator,
-                    (multi_device_iterator_ops.OwnedMultiDeviceIterator,
-                     multi_device_iterator_ops.MultiDeviceIterator)):
+      if _should_use_multi_device_iterator(self._options):
         return self._iterator.get_next(device)
       else:
         return self._iterator.get_next()
@@ -1665,25 +1673,30 @@ class _SingleWorkerDatasetIteratorBase(object):
 class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
   """Type specification for `_SingleWorkerOwnedDatasetIterator`."""
 
-  __slots__ = ["_worker", "_devices", "_element_spec"]
+  __slots__ = ["_worker", "_devices", "_element_spec", "_options"]
 
-  def __init__(self, worker, devices, element_spec):
+  def __init__(self, worker, devices, element_spec, options):
     self._worker = worker
     self._devices = tuple(device_util.canonicalize(d) for d in devices)
     self._element_spec = element_spec
+    self._options = options
 
   @property
   def value_type(self):
     return _SingleWorkerOwnedDatasetIterator
 
   def _serialize(self):
-    return (self._worker, self._devices, self._element_spec)
+    return (self._worker, self._devices, self._element_spec, self._options)
 
   @property
   def _component_specs(self):
     specs = []
-    specs.append(multi_device_iterator_ops.MultiDeviceIteratorSpec(
-        self._devices, self._worker, element_spec=self._element_spec))
+    if _should_use_multi_device_iterator(self._options):
+      specs.append(
+          multi_device_iterator_ops.MultiDeviceIteratorSpec(
+              self._devices, self._worker, element_spec=self._element_spec))
+    else:
+      specs.append(iterator_ops.IteratorSpec(element_spec=self._element_spec))
     return specs
 
   def _to_components(self, value):
@@ -1695,13 +1708,14 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
         worker=self._worker,
         devices=self._devices,
         components=components,
-        element_spec=self._element_spec)
+        element_spec=self._element_spec,
+        options=self._options)
 
   @staticmethod
   def from_value(value):
     # pylint: disable=protected-access
     return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
-                                            value._element_spec)
+                                            value._element_spec, value._options)
 
 
 class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
@@ -1758,10 +1772,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
     if not self._worker:
       raise ValueError("Worked device must be specified when creating an "
                        "owned iterator.")
-    if (self._options is None or self._options.experimental_replication_mode ==
-        InputReplicationMode.PER_WORKER or
-        (self._options.experimental_replication_mode == InputReplicationMode
-         .PER_REPLICA and self._options.experimental_prefetch_to_device)):
+    if _should_use_multi_device_iterator(self._options):
       host_device = device_util.get_host_for_device(self._worker)
       with ops.device(self._worker):
         self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
@@ -1777,7 +1788,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
   @property
   def _type_spec(self):
     return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
-                                            self._element_spec)
+                                            self._element_spec, self._options)
 
   @property
   def output_classes(self):
@@ -1908,7 +1919,7 @@ def _create_iterators_per_worker(worker_datasets,
             options=options)
       else:
         iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
-                                                worker_devices)
+                                                worker_devices, options)
       iterators.append(iterator)
   return iterators
 
@@ -1989,6 +2000,17 @@ def _get_dataset_attributes(dataset):
   return batch_size, drop_remainder, prefetch_buffer
 
 
+def _should_use_multi_device_iterator(options):
+  """Determine whether to use multi_device_iterator_ops."""
+  if (options is None or
+      options.experimental_replication_mode == InputReplicationMode.PER_WORKER
+      or
+      (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
+       and options.experimental_prefetch_to_device)):
+    return True
+  return False
+
+
 class MultiStepContext(object):
   """A context object that can be used to capture things when running steps.
 
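Note (not part of the diff): the helper above centralizes the routing decision that the earlier hunks now delegate to it (the `get_next` dispatch and the iterator construction in `_SingleWorkerOwnedDatasetIterator`). A minimal standalone sketch of the same rule, assuming only the `InputOptions` and `InputReplicationMode` fields already used in this change:

from tensorflow.python.distribute import distribute_lib


def uses_multi_device_iterator(options):
  # Mirrors _should_use_multi_device_iterator: keep the MultiDeviceIterator
  # path unless the input is replicated PER_REPLICA with device prefetching
  # turned off.
  if options is None:
    return True
  if (options.experimental_replication_mode ==
      distribute_lib.InputReplicationMode.PER_WORKER):
    return True
  return options.experimental_prefetch_to_device


# PER_REPLICA without prefetch_to_device falls back to a plain per-device
# iterator (iterator_ops.IteratorSpec in _component_specs above).
per_replica_no_prefetch = distribute_lib.InputOptions(
    experimental_prefetch_to_device=False,
    experimental_replication_mode=(
        distribute_lib.InputReplicationMode.PER_REPLICA))

assert uses_multi_device_iterator(None)
assert not uses_multi_device_iterator(per_replica_no_prefetch)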
@@ -450,6 +450,76 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
       f(v)
     self.assertEqual(self.trace_count, 1)
 
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          tf_api_version=2,
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_two_gpus,
+              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          enable_get_next_as_optional=[True, False],
+          experimental_place_dataset_on_device=[True, False],
+          experimental_prefetch_to_device=[True, False],
+      ))
+  def testFromFunctionInputSignatureForPerReplicaValuesWithOptions(
+      self, distribution, enable_get_next_as_optional,
+      experimental_place_dataset_on_device, experimental_prefetch_to_device):
+
+    if experimental_place_dataset_on_device and experimental_prefetch_to_device:
+      self.skipTest("Setting experimental_place_dataset_on_device and "
+                    "experimental_prefetch_to_device to `True` is not "
+                    "allowed when using "
+                    "distribute_lib.InputReplicationMode.PER_REPLICA.")
+
+    fname1 = os.path.join(self.get_temp_dir(), "1.txt")
+    _create_text_file(fname1, 5)
+    fname2 = os.path.join(self.get_temp_dir(), "2.txt")
+    _create_text_file(fname2, 9)
+
+    def dataset_fn(input_context):
+      dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
+      dataset = dataset.shard(input_context.num_input_pipelines,
+                              input_context.input_pipeline_id)
+      return readers.TextLineDatasetV2(dataset).map(
+          string_ops.string_to_number).batch(
+              input_context.get_per_replica_batch_size(4))
+
+    options = distribute_lib.InputOptions(
+        experimental_place_dataset_on_device=
+        experimental_place_dataset_on_device,
+        experimental_prefetch_to_device=experimental_prefetch_to_device,
+        experimental_replication_mode=(
+            distribute_lib.InputReplicationMode.PER_REPLICA))
+
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+    ds = distribution.experimental_distribute_datasets_from_function(
+        dataset_fn, options)
+
+    iterator = iter(ds)
+    _check_type_spec_structure(iterator)
+    spec = iterator._type_spec
+    tensor_list = spec._to_components(iterator)
+    re_iterator = spec._from_components(tensor_list)
+
+    _check_type_spec_structure(iter(ds))
+    element_spec = ds.element_spec
+    iter_element_spec = iter(ds).element_spec
+    nest.assert_same_structure(element_spec, iter_element_spec)
+    self.assertAllEqual(
+        nest.flatten(element_spec), nest.flatten(iter_element_spec))
+    self.assertEqual(iterator._input_workers, re_iterator._input_workers)
+    self.assertAllEqual(iterator._iterators, re_iterator._iterators)
+
+    @def_function.function(input_signature=[element_spec])
+    def process_inputs(inputs):
+      distribution.run(lambda inputs: inputs, args=(inputs,))
+
+    for x in ds:
+      process_inputs(x)
+
 
 class RaggedTensorDistributedIteratorTest(test.TestCase,
                                           parameterized.TestCase):
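For reference, a condensed sketch of the user-facing path this new test exercises. Assumptions: a MirroredStrategy (the test parameterizes over two-device strategies; the default constructor here just picks up whatever devices are available), the internal `distribute_lib` import the test uses for `InputOptions`, and `experimental_distribute_datasets_from_function`, which is the API spelling at this revision:

import tensorflow as tf
from tensorflow.python.distribute import distribute_lib

strategy = tf.distribute.MirroredStrategy()


def dataset_fn(input_context):
  # Each replica builds its own pipeline and batches to its per-replica size.
  batch_size = input_context.get_per_replica_batch_size(4)
  return tf.data.Dataset.range(16).batch(batch_size)


options = distribute_lib.InputOptions(
    experimental_prefetch_to_device=False,
    experimental_replication_mode=(
        distribute_lib.InputReplicationMode.PER_REPLICA))

ds = strategy.experimental_distribute_datasets_from_function(dataset_fn, options)
for batch in ds:
  strategy.run(lambda x: x, args=(batch,))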