commit d2a8bec3cd
Author: TensorFlower Gardener
Date:   2020-11-12 16:52:16 -08:00

Merge pull request from kushanam:keras_distribute_lib

PiperOrigin-RevId: 342159822
Change-Id: I36d2f359ede1ed853a279fe705f79be8860ad5b8
2 changed files with 122 additions and 30 deletions
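In brief, the change threads the `InputOptions` a distributed dataset was built with into `DistributedIteratorSpec` and `_SingleWorkerDatasetIteratorSpec`, and adds a `_should_use_multi_device_iterator()` helper so that `InputReplicationMode.PER_REPLICA` without device prefetching is backed by a plain `iterator_ops.IteratorSpec` instead of a `MultiDeviceIteratorSpec`. A minimal usage sketch of the option combination this plumbing serves (illustrative only, not part of the commit; it assumes a TF 2.4-era build, two or more local devices, and the same public API names the new test exercises):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # assumes >= 2 local devices

def dataset_fn(input_context):
  # Under PER_REPLICA replication, dataset_fn runs once per replica.
  batch_size = input_context.get_per_replica_batch_size(4)
  return tf.data.Dataset.range(16).batch(batch_size)

options = tf.distribute.InputOptions(
    experimental_prefetch_to_device=False,
    experimental_replication_mode=(
        tf.distribute.InputReplicationMode.PER_REPLICA))

ds = strategy.experimental_distribute_datasets_from_function(
    dataset_fn, options)

for batch in ds:
  strategy.run(lambda x: x + 1, args=(batch,))

With `experimental_prefetch_to_device=False`, batches are not staged onto the compute devices ahead of time, which is exactly the case the new iterator and type-spec path has to represent.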
tensorflow/python/distribute


@@ -28,6 +28,7 @@ from tensorflow.python.data.experimental.ops import batching
 from tensorflow.python.data.experimental.ops import cardinality
 from tensorflow.python.data.experimental.ops import distribute
 from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.data.ops import multi_device_iterator_ops
 from tensorflow.python.data.ops import optional_ops
 from tensorflow.python.distribute import device_util
@@ -759,11 +760,11 @@ class DistributedIteratorSpec(type_spec.TypeSpec):

   __slots__ = [
       "_input_workers", "_element_spec", "_strategy",
-      "_enable_get_next_as_optional"
+      "_enable_get_next_as_optional", "_options"
   ]

   def __init__(self, input_workers, element_spec, strategy,
-               enable_get_next_as_optional):
+               enable_get_next_as_optional, options):
     # We don't want to allow deserialization of this class because we don't
     # serialize the strategy object. Currently the only places where
     # _deserialize is called is when we save/restore using SavedModels.
@@ -775,6 +776,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
     self._element_spec = element_spec
     self._strategy = strategy
     self._enable_get_next_as_optional = enable_get_next_as_optional
+    self._options = options

   @property
   def value_type(self):
@@ -783,8 +785,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
   def _serialize(self):
     # We cannot serialize the strategy object so we convert it to an id that we
     # can use for comparison.
-    return (self._input_workers.serialize(),
-            self._element_spec, id(self._strategy))
+    return (self._input_workers.serialize(), self._element_spec,
+            id(self._strategy), id(self._options))

   def _deserialize(self):
     raise ValueError("Deserialization is currently unsupported for "
@@ -816,7 +818,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
         other._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)

   @property
   def _component_specs(self):
@@ -826,9 +829,9 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
     for i, (input_device, compute_devices) in enumerate(worker_device_pairs):
       element_spec = nest.map_structure(
           functools.partial(_replace_per_replica_spec, i=i), self._element_spec)
-      specs.append(_SingleWorkerDatasetIteratorSpec(input_device,
-                                                    compute_devices,
-                                                    element_spec))
+      specs.append(
+          _SingleWorkerDatasetIteratorSpec(input_device, compute_devices,
+                                           element_spec, self._options))
     return specs

   def _to_components(self, value):
@@ -841,14 +844,16 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
         components=components,
         element_spec=self._element_spec,
         strategy=self._strategy,
-        enable_get_next_as_optional=self._enable_get_next_as_optional)
+        enable_get_next_as_optional=self._enable_get_next_as_optional,
+        options=self._options)

   @staticmethod
   def from_value(value):
     # pylint: disable=protected-access
     return DistributedIteratorSpec(value._input_workers, value._element_spec,
                                    value._strategy,
-                                   value._enable_get_next_as_optional)
+                                   value._enable_get_next_as_optional,
+                                   value._options)

   def _with_tensor_ranks_only(self):
     element_spec = nest.map_structure(
@@ -856,7 +861,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
         self._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)


 class DistributedIterator(DistributedIteratorBase,
@@ -869,7 +875,8 @@ class DistributedIterator(DistributedIteratorBase,
                strategy=None,
                components=None,
                element_spec=None,
-               enable_get_next_as_optional=False):
+               enable_get_next_as_optional=False,
+               options=None):
     if input_workers is None:
       raise ValueError("`input_workers` should be "
                        "provided.")
@@ -877,6 +884,7 @@ class DistributedIterator(DistributedIteratorBase,
     error_message = ("Either `input_workers` or "
                      "both `components` and `element_spec` need to be "
                      "provided.")
+    self._options = options

     if iterators is None:
       if (components is None or element_spec is None):
@@ -916,7 +924,8 @@ class DistributedIterator(DistributedIteratorBase,
     # TODO(b/163362689): remove the comment after the bug if fixed.
     return DistributedIteratorSpec(self._input_workers, self._element_spec,
                                    self._strategy,
-                                   self._enable_get_next_as_optional)
+                                   self._enable_get_next_as_optional,
+                                   self._options)


 class _IterableInput(DistributedDatasetInterface):
@@ -1290,7 +1299,8 @@ class DistributedDatasetsFromFunction(_IterableInput):
           input_workers=self._input_workers,
           iterators=iterators,
           strategy=self._strategy,
-          enable_get_next_as_optional=self._enable_get_next_as_optional)
+          enable_get_next_as_optional=self._enable_get_next_as_optional,
+          options=self._options)
       iterator._element_spec = self._element_spec  # pylint: disable=protected-access

       # When async eager is enabled, sometimes the iterator may not finish
@@ -1585,9 +1595,7 @@ class _SingleWorkerDatasetIteratorBase(object):
     """Get next element for the given device."""
     del name
     with ops.device(self._worker):
-      if isinstance(self._iterator,
-                    (multi_device_iterator_ops.OwnedMultiDeviceIterator,
-                     multi_device_iterator_ops.MultiDeviceIterator)):
+      if _should_use_multi_device_iterator(self._options):
         return self._iterator.get_next(device)
       else:
         return self._iterator.get_next()
@@ -1665,25 +1673,30 @@ class _SingleWorkerDatasetIteratorBase(object):
 class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
   """Type specification for `_SingleWorkerOwnedDatasetIterator`."""

-  __slots__ = ["_worker", "_devices", "_element_spec"]
+  __slots__ = ["_worker", "_devices", "_element_spec", "_options"]

-  def __init__(self, worker, devices, element_spec):
+  def __init__(self, worker, devices, element_spec, options):
     self._worker = worker
     self._devices = tuple(device_util.canonicalize(d) for d in devices)
     self._element_spec = element_spec
+    self._options = options

   @property
   def value_type(self):
     return _SingleWorkerOwnedDatasetIterator

   def _serialize(self):
-    return (self._worker, self._devices, self._element_spec)
+    return (self._worker, self._devices, self._element_spec, self._options)

   @property
   def _component_specs(self):
     specs = []
-    specs.append(multi_device_iterator_ops.MultiDeviceIteratorSpec(
-        self._devices, self._worker, element_spec=self._element_spec))
+    if _should_use_multi_device_iterator(self._options):
+      specs.append(
+          multi_device_iterator_ops.MultiDeviceIteratorSpec(
+              self._devices, self._worker, element_spec=self._element_spec))
+    else:
+      specs.append(iterator_ops.IteratorSpec(element_spec=self._element_spec))
     return specs

   def _to_components(self, value):
@@ -1695,13 +1708,14 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
         worker=self._worker,
         devices=self._devices,
         components=components,
-        element_spec=self._element_spec)
+        element_spec=self._element_spec,
+        options=self._options)

   @staticmethod
   def from_value(value):
     # pylint: disable=protected-access
     return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
-                                            value._element_spec)
+                                            value._element_spec, value._options)


 class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
@@ -1758,10 +1772,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
     if not self._worker:
       raise ValueError("Worked device must be specified when creating an "
                        "owned iterator.")
-    if (self._options is None or self._options.experimental_replication_mode ==
-        InputReplicationMode.PER_WORKER or
-        (self._options.experimental_replication_mode == InputReplicationMode
-         .PER_REPLICA and self._options.experimental_prefetch_to_device)):
+    if _should_use_multi_device_iterator(self._options):
       host_device = device_util.get_host_for_device(self._worker)
       with ops.device(self._worker):
         self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
@@ -1777,7 +1788,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
   @property
   def _type_spec(self):
     return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
-                                            self._element_spec)
+                                            self._element_spec, self._options)

   @property
   def output_classes(self):
@@ -1908,7 +1919,7 @@ def _create_iterators_per_worker(worker_datasets,
             options=options)
       else:
         iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
-                                                worker_devices)
+                                                worker_devices, options)
       iterators.append(iterator)
   return iterators
@@ -1989,6 +2000,17 @@ def _get_dataset_attributes(dataset):
   return batch_size, drop_remainder, prefetch_buffer


+def _should_use_multi_device_iterator(options):
+  """Determine whether to use multi_device_iterator_ops."""
+  if (options is None or
+      options.experimental_replication_mode == InputReplicationMode.PER_WORKER
+      or
+      (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA
+       and options.experimental_prefetch_to_device)):
+    return True
+  return False
+
+
 class MultiStepContext(object):
   """A context object that can be used to capture things when running steps.

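For reference, the decision encoded by the new `_should_use_multi_device_iterator()` helper, restated as a standalone sketch (illustrative only; it re-derives the same logic instead of importing the private helper, and uses the `distribute_lib` names the test below imports):

from tensorflow.python.distribute import distribute_lib

def uses_multi_device_iterator(options):
  # Mirrors the helper added above: every case except PER_REPLICA with
  # prefetching disabled keeps the MultiDeviceIterator-backed path.
  if options is None:
    return True
  if (options.experimental_replication_mode ==
      distribute_lib.InputReplicationMode.PER_WORKER):
    return True
  return options.experimental_prefetch_to_device  # PER_REPLICA case

assert uses_multi_device_iterator(None)
assert uses_multi_device_iterator(distribute_lib.InputOptions())
assert not uses_multi_device_iterator(
    distribute_lib.InputOptions(
        experimental_replication_mode=(
            distribute_lib.InputReplicationMode.PER_REPLICA),
        experimental_prefetch_to_device=False))

When it returns False, `_component_specs` above falls back to a plain `iterator_ops.IteratorSpec` rather than a `MultiDeviceIteratorSpec`, and `get_next()` calls the iterator without a target device.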

@@ -450,6 +450,76 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
     f(v)
     self.assertEqual(self.trace_count, 1)

+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          tf_api_version=2,
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_two_gpus,
+              strategy_combinations.mirrored_strategy_with_cpu_1_and_2,
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+          ],
+          enable_get_next_as_optional=[True, False],
+          experimental_place_dataset_on_device=[True, False],
+          experimental_prefetch_to_device=[True, False],
+      ))
+  def testFromFunctionInputSignatureForPerReplicaValuesWithOptions(
+      self, distribution, enable_get_next_as_optional,
+      experimental_place_dataset_on_device, experimental_prefetch_to_device):
+    if experimental_place_dataset_on_device and experimental_prefetch_to_device:
+      self.skipTest("Setting experimental_place_dataset_on_device and "
+                    "experimental_prefetch_to_device to `True` is not "
+                    "allowed when using "
+                    "distribute_lib.InputReplicationMode.PER_REPLICA.")
+
+    fname1 = os.path.join(self.get_temp_dir(), "1.txt")
+    _create_text_file(fname1, 5)
+    fname2 = os.path.join(self.get_temp_dir(), "2.txt")
+    _create_text_file(fname2, 9)
+
+    def dataset_fn(input_context):
+      dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
+      dataset = dataset.shard(input_context.num_input_pipelines,
+                              input_context.input_pipeline_id)
+      return readers.TextLineDatasetV2(dataset).map(
+          string_ops.string_to_number).batch(
+              input_context.get_per_replica_batch_size(4))
+
+    options = distribute_lib.InputOptions(
+        experimental_place_dataset_on_device=
+        experimental_place_dataset_on_device,
+        experimental_prefetch_to_device=experimental_prefetch_to_device,
+        experimental_replication_mode=(
+            distribute_lib.InputReplicationMode.PER_REPLICA))
+
+    distribution.extended.experimental_enable_get_next_as_optional = (
+        enable_get_next_as_optional)
+
+    ds = distribution.experimental_distribute_datasets_from_function(
+        dataset_fn, options)
+
+    iterator = iter(ds)
+    _check_type_spec_structure(iterator)
+    spec = iterator._type_spec
+    tensor_list = spec._to_components(iterator)
+    re_iterator = spec._from_components(tensor_list)
+
+    _check_type_spec_structure(iter(ds))
+    element_spec = ds.element_spec
+    iter_element_spec = iter(ds).element_spec
+    nest.assert_same_structure(element_spec, iter_element_spec)
+    self.assertAllEqual(
+        nest.flatten(element_spec), nest.flatten(iter_element_spec))
+    self.assertEqual(iterator._input_workers, re_iterator._input_workers)
+    self.assertAllEqual(iterator._iterators, re_iterator._iterators)
+
+    @def_function.function(input_signature=[element_spec])
+    def process_inputs(inputs):
+      distribution.run(lambda inputs: inputs, args=(inputs,))
+
+    for x in ds:
+      process_inputs(x)
+

 class RaggedTensorDistributedIteratorTest(test.TestCase,
                                           parameterized.TestCase):
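The core round trip the new test checks, shown in isolation (illustrative; it touches the same private `_type_spec` machinery the test does and reuses `ds` from the first sketch above):

iterator = iter(ds)
spec = iterator._type_spec                  # DistributedIteratorSpec, now carrying the options
components = spec._to_components(iterator)
rebuilt = spec._from_components(components)
assert iterator._input_workers == rebuilt._input_workers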