Merge pull request #44998 from geetachavan1/cherrypicks_KXZX5

[CherryPick:r2.4] add per_replica support for keras

This commit is contained in:
commit c34eac578e
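This cherry-pick plumbs the input options object through the distributed-input type specs (`DistributedIteratorSpec` and `_SingleWorkerDatasetIteratorSpec`), so that iterators created with `experimental_replication_mode=PER_REPLICA` keep their options when they are rebuilt from a TypeSpec, as happens whenever Keras feeds them through a `tf.function`. The sketch below shows the user-facing behavior this plumbing enables; it is illustrative only, uses the public TF 2.4 names corresponding to the internal `distribute_lib` symbols in the diff, and assumes a machine with two GPUs:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

# PER_REPLICA replication invokes dataset_fn once per replica rather than
# once per worker; prefetching to device is turned off here to exercise the
# plain host-iterator path added by this change.
options = tf.distribute.InputOptions(
    experimental_prefetch_to_device=False,
    experimental_replication_mode=tf.distribute.InputReplicationMode.PER_REPLICA)

def dataset_fn(input_context):
  batch_size = input_context.get_per_replica_batch_size(4)
  return tf.data.Dataset.range(16).batch(batch_size)

ds = strategy.experimental_distribute_datasets_from_function(dataset_fn, options)

@tf.function
def step(inputs):
  return strategy.run(lambda x: x, args=(inputs,))

for batch in ds:
  step(batch)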
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import cardinality
 from tensorflow.python.data.experimental.ops import distribute
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.distribute import device_util
 from tensorflow.python.distribute import distribute_utils
@@ -759,11 +760,11 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
 
   __slots__ = [
       "_input_workers", "_element_spec", "_strategy",
-      "_enable_get_next_as_optional"
+      "_enable_get_next_as_optional", "_options"
   ]
 
   def __init__(self, input_workers, element_spec, strategy,
-               enable_get_next_as_optional):
+               enable_get_next_as_optional, options):
     # We don't want to allow deserialization of this class because we don't
     # serialize the strategy object. Currently the only places where
     # _deserialize is called is when we save/restore using SavedModels.
@@ -775,6 +776,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
     self._element_spec = element_spec
     self._strategy = strategy
     self._enable_get_next_as_optional = enable_get_next_as_optional
+    self._options = options
 
   @property
   def value_type(self):
@@ -784,7 +786,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
     # We cannot serialize the strategy object so we convert it to an id that we
     # can use for comparison.
     return (self._input_workers.serialize(),
-            self._element_spec, id(self._strategy))
+            self._element_spec, id(self._strategy), id(self._options))
 
   def _deserialize(self):
     raise ValueError("Deserialization is currently unsupported for "
@@ -816,7 +818,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
         other._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)
 
   @property
   def _component_specs(self):
@@ -828,7 +831,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
           functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
       specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
                                                     compute_devices,
-                                                    element_spec))
+                                                    element_spec,
+                                                    self._options))
     return specs
 
   def _to_components(self, value):
@@ -841,14 +845,16 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
         components=components,
         element_spec=self._element_spec,
         strategy=self._strategy,
-        enable_get_next_as_optional=self._enable_get_next_as_optional)
+        enable_get_next_as_optional=self._enable_get_next_as_optional,
+        options=self._options)
 
   @staticmethod
   def from_value(value):
     # pylint: disable=protected-access
     return DistributedIteratorSpec(value._input_workers, value._element_spec,
                                    value._strategy,
-                                   value._enable_get_next_as_optional)
+                                   value._enable_get_next_as_optional,
+                                   value._options)
 
   def _with_tensor_ranks_only(self):
     element_spec = nest.map_structure(
@@ -856,7 +862,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
         self._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)
 
 
 class DistributedIterator(DistributedIteratorBase,
@@ -869,7 +876,8 @@ class DistributedIterator(DistributedIteratorBase,
                strategy=None,
                components=None,
                element_spec=None,
-               enable_get_next_as_optional=False):
+               enable_get_next_as_optional=False,
+               options=None):
     if input_workers is None:
       raise ValueError("`input_workers` should be "
                        "provided.")
@@ -877,6 +885,7 @@ class DistributedIterator(DistributedIteratorBase,
     error_message = ("Either `input_workers` or "
                      "both `components` and `element_spec` need to be "
                      "provided.")
+    self._options = options
 
     if iterators is None:
       if (components is None or element_spec is None):
@@ -916,7 +925,8 @@ class DistributedIterator(DistributedIteratorBase,
     # TODO(b/163362689): remove the comment after the bug if fixed.
     return DistributedIteratorSpec(self._input_workers, self._element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)
 
 
 class _IterableInput(DistributedDatasetInterface):
@@ -1290,7 +1300,8 @@ class DistributedDatasetsFromFunction(_IterableInput):
          input_workers=self._input_workers,
          iterators=iterators,
          strategy=self._strategy,
-          enable_get_next_as_optional=self._enable_get_next_as_optional)
+          enable_get_next_as_optional=self._enable_get_next_as_optional,
+          options=self._options)
       iterator._element_spec = self._element_spec  # pylint: disable=protected-access
 
       # When async eager is enabled, sometimes the iterator may not finish
@@ -1585,9 +1596,7 @@ class _SingleWorkerDatasetIteratorBase(object):
     """Get next element for the given device."""
     del name
     with ops.device(self._worker):
-      if isinstance(self._iterator,
-                    (multi_device_iterator_ops.OwnedMultiDeviceIterator,
-                     multi_device_iterator_ops.MultiDeviceIterator)):
+      if _should_use_multi_device_iterator(self._options):
         return self._iterator.get_next(device)
       else:
         return self._iterator.get_next()
@@ -1665,25 +1674,30 @@ class _SingleWorkerDatasetIteratorBase(object):
 class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
   """Type specification for `_SingleWorkerOwnedDatasetIterator`."""
 
-  __slots__ = ["_worker", "_devices", "_element_spec"]
+  __slots__ = ["_worker", "_devices", "_element_spec", "_options"]
 
-  def __init__(self, worker, devices, element_spec):
+  def __init__(self, worker, devices, element_spec, options):
     self._worker = worker
     self._devices = tuple(device_util.canonicalize(d) for d in devices)
     self._element_spec = element_spec
+    self._options = options
 
   @property
   def value_type(self):
     return _SingleWorkerOwnedDatasetIterator
 
   def _serialize(self):
-    return (self._worker, self._devices, self._element_spec)
+    return (self._worker, self._devices, self._element_spec, self._options)
 
   @property
   def _component_specs(self):
     specs = []
+    if _should_use_multi_device_iterator(self._options):
       specs.append(multi_device_iterator_ops.MultiDeviceIteratorSpec(
           self._devices, self._worker, element_spec=self._element_spec))
+    else:
+      specs.append(iterator_ops.IteratorSpec(
+          element_spec=self._element_spec))
     return specs
 
   def _to_components(self, value):
@@ -1695,13 +1709,14 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
         worker=self._worker,
         devices=self._devices,
         components=components,
-        element_spec=self._element_spec)
+        element_spec=self._element_spec,
+        options=self._options)
 
   @staticmethod
   def from_value(value):
     # pylint: disable=protected-access
     return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
-                                            value._element_spec)
+                                            value._element_spec, value._options)
 
 
 class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
@@ -1758,10 +1773,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
     if not self._worker:
       raise ValueError("Worked device must be specified when creating an "
                        "owned iterator.")
-    if (self._options is None or self._options.experimental_replication_mode ==
-        InputReplicationMode.PER_WORKER or
-        (self._options.experimental_replication_mode == InputReplicationMode
-         .PER_REPLICA and self._options.experimental_prefetch_to_device)):
+    if _should_use_multi_device_iterator(self._options):
       host_device = device_util.get_host_for_device(self._worker)
       with ops.device(self._worker):
         self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
@@ -1777,7 +1789,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
   @property
   def _type_spec(self):
     return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
-                                            self._element_spec)
+                                            self._element_spec, self._options)
 
   @property
   def output_classes(self):
@@ -1908,7 +1920,7 @@ def _create_iterators_per_worker(worker_datasets,
                                                    options=options)
     else:
       iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
-                                              worker_devices)
+                                              worker_devices, options)
     iterators.append(iterator)
   return iterators
 
@@ -1988,6 +2000,16 @@ def _get_dataset_attributes(dataset):
 
   return batch_size, drop_remainder, prefetch_buffer
 
+def _should_use_multi_device_iterator(options):
+  """Determine whether to use multi_device_iterator_ops.OwnedMultiDeviceIterator"""
+  if (options is None or
+      options.experimental_replication_mode == InputReplicationMode.PER_WORKER
+      or
+      (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
+       and options.experimental_prefetch_to_device)):
+    return True
+  return False
 
+
 class MultiStepContext(object):
   """A context object that can be used to capture things when running steps.
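The new `_should_use_multi_device_iterator` helper centralizes the check previously inlined in `_SingleWorkerOwnedDatasetIterator`: a `MultiDeviceIterator` is still used for the default PER_WORKER replication mode, and for PER_REPLICA with prefetching to device enabled, while PER_REPLICA without device prefetch now gets a plain host-side iterator (hence the `iterator_ops.IteratorSpec` branch in `_component_specs`). A minimal sketch of the resulting decision table, written against the internal modules this diff touches and assuming a TF 2.4 source tree where they are importable:

# Sketch only: exercises the helper added by this diff via internal modules;
# not part of the change itself.
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute.distribute_lib import InputReplicationMode

# No options at all: default PER_WORKER behavior, MultiDeviceIterator is used.
assert input_lib._should_use_multi_device_iterator(None)

# PER_REPLICA with prefetch_to_device still uses a MultiDeviceIterator.
opts = distribute_lib.InputOptions(
    experimental_prefetch_to_device=True,
    experimental_replication_mode=InputReplicationMode.PER_REPLICA)
assert input_lib._should_use_multi_device_iterator(opts)

# PER_REPLICA without prefetch_to_device uses a plain per-replica iterator.
opts = distribute_lib.InputOptions(
    experimental_prefetch_to_device=False,
    experimental_replication_mode=InputReplicationMode.PER_REPLICA)
assert not input_lib._should_use_multi_device_iterator(opts)

The test added below round-trips a `DistributedIterator` through `_to_components`/`_from_components` for every combination of the new options and then feeds the reconstructed values to `strategy.run` under a fixed `tf.function` input signature.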
diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py
--- a/tensorflow/python/distribute/input_lib_type_spec_test.py
+++ b/tensorflow/python/distribute/input_lib_type_spec_test.py
@@ -450,6 +450,75 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
       f(v)
     self.assertEqual(self.trace_count, 1)
 
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          tf_api_version=2,
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_two_gpus,
+              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          enable_get_next_as_optional=[True, False],
+          experimental_place_dataset_on_device=[True, False],
+          experimental_prefetch_to_device=[True, False],))
+  def testFromFunctionInputSignatureForPerReplicaValuesWithOptions(
+      self, distribution, enable_get_next_as_optional,
+      experimental_place_dataset_on_device,
+      experimental_prefetch_to_device):
+
+    if experimental_place_dataset_on_device and experimental_prefetch_to_device:
+      self.skipTest("Setting experimental_place_dataset_on_device and "
+                    "experimental_prefetch_to_device to `True` is not "
+                    "allowed when using "
+                    "distribute_lib.InputReplicationMode.PER_REPLICA.")
+
+    fname1 = os.path.join(self.get_temp_dir(), "1.txt")
+    _create_text_file(fname1, 5)
+    fname2 = os.path.join(self.get_temp_dir(), "2.txt")
+    _create_text_file(fname2, 9)
+
+    def dataset_fn(input_context):
+      dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
+      dataset = dataset.shard(input_context.num_input_pipelines,
+                              input_context.input_pipeline_id)
+      return readers.TextLineDatasetV2(dataset).map(
+          string_ops.string_to_number).batch(
+              input_context.get_per_replica_batch_size(4))
+
+    options = distribute_lib.InputOptions(
+        experimental_place_dataset_on_device=experimental_place_dataset_on_device,
+        experimental_prefetch_to_device=experimental_prefetch_to_device,
+        experimental_replication_mode=(
+            distribute_lib.InputReplicationMode.PER_REPLICA
+        ))
+
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+    ds = distribution.experimental_distribute_datasets_from_function(dataset_fn,
+                                                                     options)
+
+    iterator = iter(ds)
+    _check_type_spec_structure(iterator)
+    spec = iterator._type_spec
+    tensor_list = spec._to_components(iterator)
+    re_iterator = spec._from_components(tensor_list)
+
+    _check_type_spec_structure(iter(ds))
+    element_spec = ds.element_spec
+    iter_element_spec = iter(ds).element_spec
+    nest.assert_same_structure(element_spec, iter_element_spec)
+    self.assertAllEqual(
+        nest.flatten(element_spec), nest.flatten(iter_element_spec))
+    self.assertEqual(iterator._input_workers, re_iterator._input_workers)
+    self.assertAllEqual(iterator._iterators, re_iterator._iterators)
+
+    @def_function.function(input_signature=[element_spec])
+    def process_inputs(inputs):
+      distribution.run(lambda inputs: inputs, args=(inputs,))
+
+    for x in ds:
+      process_inputs(x)
 
 class RaggedTensorDistributedIteratorTest(test.TestCase,
                                           parameterized.TestCase):
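A closing note on retracing: `DistributedIteratorSpec._serialize` now includes `id(self._options)`, so specs compare equal only when they share the same options object, and a `tf.function` traced for one distributed iterator will retrace for an iterator carrying a different `InputOptions` instance. A hedged illustration (assuming a single-CPU `MirroredStrategy`; not part of the diff):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])

def dataset_fn(ctx):
  return tf.data.Dataset.range(8).batch(ctx.get_per_replica_batch_size(4))

opts_a = tf.distribute.InputOptions()
opts_b = tf.distribute.InputOptions()
ds_a = strategy.experimental_distribute_datasets_from_function(dataset_fn, opts_a)
ds_b = strategy.experimental_distribute_datasets_from_function(dataset_fn, opts_b)

# _serialize() includes id(self._options), so the specs differ even though
# opts_a and opts_b hold identical field values.
assert iter(ds_a)._type_spec != iter(ds_b)._type_spec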