Add InputOptions to experimental_distribute_dataset(s_from_function).

PiperOrigin-RevId: 316563848
Change-Id: I00d54d309395754a6182829725f42e1f968f14c4
This commit is contained in:
Bruce Fontaine 2020-06-15 16:01:40 -07:00 committed by TensorFlower Gardener
parent 7292433984
commit d29d8af754
21 changed files with 186 additions and 76 deletions

View File

@ -409,7 +409,7 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
num_replicas_in_sync=self._num_replicas_in_sync)
return input_context
def _experimental_distribute_dataset(self, dataset):
def _experimental_distribute_dataset(self, dataset, options):
input_context = self._make_input_context()
return input_lib.get_distributed_dataset(
dataset,
@ -418,7 +418,8 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
split_batch_by=self._num_replicas_in_sync,
input_context=input_context)
def _experimental_distribute_datasets_from_function(self, dataset_fn):
def _experimental_distribute_datasets_from_function(self, dataset_fn,
options):
input_context = self._make_input_context()
return input_lib.get_distributed_datasets_from_function(
dataset_fn=dataset_fn,

View File

@ -602,6 +602,43 @@ class RunOptions(
cls).__new__(cls, experimental_enable_dynamic_batch_size,
experimental_bucketizing_dynamic_shape)
@tf_export("distribute.InputOptions", v1=[])
class InputOptions(
    collections.namedtuple("InputOptions", [
        "experimental_prefetch_to_device",
    ])):
  """Input options for `experimental_distribute_dataset(s_from_function)`.

  This can be used to hold some strategy specific configs.

  ```python
  # Setup TPUStrategy
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  tf.config.experimental_connect_to_cluster(resolver)
  tf.tpu.experimental.initialize_tpu_system(resolver)
  strategy = tf.distribute.experimental.TPUStrategy(resolver)

  dataset = tf.data.Dataset.range(16)
  distributed_dataset_on_host = (
      strategy.experimental_distribute_dataset(
          dataset,
          tf.distribute.InputOptions(
              experimental_prefetch_to_device=False)))
  ```

  Attributes:
    experimental_prefetch_to_device: Boolean. Currently only applies to
      TPUStrategy. Defaults to True. If True, dataset elements will be
      prefetched to accelerator device memory. When False, dataset elements
      are prefetched to host device memory. Must be False when using
      TPUEmbedding API.
  """

  def __new__(cls, experimental_prefetch_to_device=True):
    # Route construction through namedtuple's __new__ so the single field
    # gets a keyword default (namedtuple fields have no defaults here).
    return super(InputOptions, cls).__new__(cls,
                                            experimental_prefetch_to_device)
# ------------------------------------------------------------------------------
# Base classes for all distribution strategies.
@ -821,7 +858,7 @@ class StrategyBase(object):
args = (input_iterator.get_next(),) if input_iterator is not None else ()
return self.run(fn, args=args)
def experimental_distribute_dataset(self, dataset):
def experimental_distribute_dataset(self, dataset, options=None):
"""Distributes a tf.data.Dataset instance provided via `dataset`.
The returned distributed dataset can be iterated over similar to how
@ -910,14 +947,17 @@ class StrategyBase(object):
Args:
dataset: `tf.data.Dataset` that will be sharded across all replicas using
the rules stated above.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
Returns:
A "distributed `Dataset`", which acts like a `tf.data.Dataset` except
it produces "per-replica" values.
"""
return self._extended._experimental_distribute_dataset(dataset) # pylint: disable=protected-access
return self._extended._experimental_distribute_dataset(dataset, options) # pylint: disable=protected-access
def experimental_distribute_datasets_from_function(self, dataset_fn):
def experimental_distribute_datasets_from_function(self, dataset_fn,
options=None):
"""Distributes `tf.data.Dataset` instances created by calls to `dataset_fn`.
`dataset_fn` will be called once for each worker in the strategy. Each
@ -973,13 +1013,15 @@ class StrategyBase(object):
Args:
dataset_fn: A function taking a `tf.distribute.InputContext` instance and
returning a `tf.data.Dataset`.
options: `tf.distribute.InputOptions` used to control options on how this
dataset is distributed.
Returns:
A "distributed `Dataset`", which acts like a `tf.data.Dataset` except
it produces "per-replica" values.
"""
return self._extended._experimental_distribute_datasets_from_function( # pylint: disable=protected-access
dataset_fn)
dataset_fn, options)
def run(self, fn, args=(), kwargs=None, options=None):
"""Run `fn` on each replica, with the given arguments.
@ -1943,10 +1985,11 @@ class StrategyExtendedV2(object):
def _make_input_fn_iterator(self, input_fn, replication_mode):
raise NotImplementedError("must be implemented in descendants")
def _experimental_distribute_dataset(self, dataset):
def _experimental_distribute_dataset(self, dataset, options):
raise NotImplementedError("must be implemented in descendants")
def _experimental_distribute_datasets_from_function(self, dataset_fn):
def _experimental_distribute_datasets_from_function(self, dataset_fn,
options):
raise NotImplementedError("must be implemented in descendants")
def _experimental_distribute_values_from_function(self, value_fn):
@ -2693,10 +2736,11 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
def variable_created_in_scope(self, v):
return v._distribute_strategy is None # pylint: disable=protected-access
def _experimental_distribute_dataset(self, dataset):
def _experimental_distribute_dataset(self, dataset, options):
return dataset
def _experimental_distribute_datasets_from_function(self, dataset_fn):
def _experimental_distribute_datasets_from_function(self, dataset_fn,
options):
return dataset_fn(InputContext())
def _experimental_distribute_values_from_function(self, value_fn):

View File

@ -89,7 +89,8 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
[distribute_lib.InputContext()],
self._container_strategy())
def _experimental_distribute_datasets_from_function(self, dataset_fn):
def _experimental_distribute_datasets_from_function(self, dataset_fn,
options):
return dataset_fn(distribute_lib.InputContext())
def _local_results(self, value):

View File

@ -476,7 +476,7 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
input_contexts,
self._container_strategy())
def _experimental_distribute_dataset(self, dataset):
def _experimental_distribute_dataset(self, dataset, options):
return input_lib.get_distributed_dataset(
dataset,
self._input_workers,
@ -487,7 +487,8 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
return numpy_dataset.one_host_numpy_dataset(
numpy_input, self._host_input_device, session)
def _experimental_distribute_datasets_from_function(self, dataset_fn):
def _experimental_distribute_datasets_from_function(self, dataset_fn,
options):
input_contexts = []
num_workers = self._input_workers.num_workers
for i in range(num_workers):

View File

@ -297,13 +297,14 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
del destinations
return tensor
def _experimental_distribute_dataset(self, dataset):
def _experimental_distribute_dataset(self, dataset, options):
# Note that split_batch_by argument is not passed because it is always 1 in
# this strategy, and adding it adds unnecessary overhead to the dataset.
return input_lib.get_distributed_dataset(dataset, self._input_workers,
self._container_strategy())
def _experimental_distribute_datasets_from_function(self, dataset_fn):
def _experimental_distribute_datasets_from_function(self, dataset_fn,
options):
return input_lib.get_distributed_datasets_from_function(
dataset_fn,
self._input_workers,

View File

@ -337,7 +337,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
def _validate_colocate_with_variable(self, colocate_with_variable):
distribute_utils.validate_colocate(colocate_with_variable, self)
def _experimental_distribute_dataset(self, dataset):
def _experimental_distribute_dataset(self, dataset, options):
return input_lib.get_distributed_dataset(
dataset,
self._input_workers,
@ -376,7 +376,8 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
return numpy_dataset.one_host_numpy_dataset(
numpy_input, self._input_host_device, session)
def _experimental_distribute_datasets_from_function(self, dataset_fn):
def _experimental_distribute_datasets_from_function(self, dataset_fn,
options):
if self._cluster_spec:
input_pipeline_id = multi_worker_util.id_in_cluster(
self._cluster_spec, self._task_type, self._task_id)

View File

@ -308,13 +308,14 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
# device 0 for each replica.
# TODO(cjfj): Create `InputWorkers` lazily, allowing users to place the
# input onto a different logical device?
input_worker_devices = collections.OrderedDict()
self._device_input_worker_devices = collections.OrderedDict()
self._host_input_worker_devices = collections.OrderedDict()
for tpu_device in self._tpu_devices[:, 0]:
host_device = device_util.get_host_for_device(tpu_device)
input_worker_devices.setdefault(host_device, [])
input_worker_devices[host_device].append(tpu_device)
self._input_worker_devices = tuple(input_worker_devices.items())
self._input_workers_obj = None
self._device_input_worker_devices.setdefault(host_device, [])
self._device_input_worker_devices[host_device].append(tpu_device)
self._host_input_worker_devices.setdefault(host_device, [])
self._host_input_worker_devices[host_device].append(host_device)
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
@ -322,7 +323,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
self._require_static_shapes = True
self.experimental_enable_get_next_as_optional = True
self._prefetch_on_host = False
self._prefetch_to_device = True
self._logical_device_stack = [0]
@ -339,38 +340,18 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
# memory and b) TPU Embedding enqueue operation are CPU ops and this avoids
# a copy back to the host for dense tensors
def _set_prefetch_on_host(self, value):
if self._prefetch_on_host == value:
return
if self._input_workers_obj is not None:
raise RuntimeError("Unable to change prefetch on host behavior as "
"InputWorkers are already created.")
self._prefetch_on_host = value
if value:
# To prefetch on the host, we must set all the input worker devices to the
# corresponding host devices.
self._input_worker_devices = tuple([
tuple([host,
[device_util.get_host_for_device(d) for d in devices]])
for host, devices in self._input_worker_devices])
# Force creation of the workers.
workers = self._input_workers
del workers
@property
def _input_workers(self):
if self._input_workers_obj is None:
self._input_workers_obj = input_lib.InputWorkers(
self._input_worker_devices)
return self._input_workers_obj
self._prefetch_to_device = not value
def _validate_colocate_with_variable(self, colocate_with_variable):
distribute_utils. validate_colocate(colocate_with_variable, self)
def _make_dataset_iterator(self, dataset):
"""Make iterators for each of the TPU hosts."""
input_workers = input_lib.InputWorkers(
tuple(self._device_input_worker_devices.items()))
return input_lib.DatasetIterator(
dataset,
self._input_workers,
input_workers,
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
@ -379,7 +360,9 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
input_fn,
replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
input_contexts = []
num_workers = self._input_workers.num_workers
input_workers = input_lib.InputWorkers(
tuple(self._device_input_worker_devices.items()))
num_workers = input_workers.num_workers
for i in range(num_workers):
input_contexts.append(distribute_lib.InputContext(
num_input_pipelines=num_workers,
@ -387,7 +370,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
num_replicas_in_sync=self._num_replicas_in_sync))
return input_lib.InputFunctionIterator(
input_fn,
self._input_workers,
input_workers,
input_contexts,
self._container_strategy())
@ -396,16 +379,29 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
numpy_input, numpy_dataset.SingleDevice(self._host_device),
session)
def _experimental_distribute_dataset(self, dataset):
def _get_input_workers(self, options):
  """Builds `InputWorkers` honoring the prefetch-to-device setting.

  Args:
    options: An optional `tf.distribute.InputOptions`. A falsy value falls
      back to the strategy-level `_prefetch_to_device` default.

  Returns:
    An `input_lib.InputWorkers` whose worker devices are the TPU devices
    when prefetching to device, or the host devices otherwise.
  """
  prefetch = (options.experimental_prefetch_to_device
              if options else self._prefetch_to_device)
  worker_devices = (self._device_input_worker_devices
                    if prefetch else self._host_input_worker_devices)
  return input_lib.InputWorkers(tuple(worker_devices.items()))
def _experimental_distribute_dataset(self, dataset, options):
return input_lib.get_distributed_dataset(
dataset,
self._input_workers,
self._get_input_workers(options),
self._container_strategy(),
split_batch_by=self._num_replicas_in_sync)
def _experimental_distribute_datasets_from_function(self, dataset_fn):
def _experimental_distribute_datasets_from_function(self, dataset_fn,
options):
input_workers = self._get_input_workers(options)
input_contexts = []
num_workers = self._input_workers.num_workers
num_workers = input_workers.num_workers
for i in range(num_workers):
input_contexts.append(distribute_lib.InputContext(
num_input_pipelines=num_workers,
@ -414,7 +410,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
return input_lib.get_distributed_datasets_from_function(
dataset_fn,
self._input_workers,
input_workers,
input_contexts,
self._container_strategy())

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import tpu_strategy as tpu_lib
@ -30,6 +31,7 @@ from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@ -546,6 +548,46 @@ class TPUStrategyTest(test.TestCase):
update_variable.get_concrete_function()
self.assertEqual(trace_count[0], len(strategy.extended.worker_devices))
def test_prefetch_to_device_default(self):
  """With no InputOptions, dataset elements should land on the TPU."""
  strategy = get_tpu_strategy()
  replicas = strategy.num_replicas_in_sync
  dataset = dataset_ops.Dataset.range(
      replicas * 2, output_type=dtypes.float32).batch(replicas)

  # No options supplied: default behavior prefetches onto the accelerator.
  first_item = next(iter(strategy.experimental_distribute_dataset(dataset)))
  placement = tf_device.DeviceSpec.from_string(first_item.values[0].device)
  self.assertEqual(placement.device_type, "TPU")
def test_prefetch_to_device_tpu(self):
  """Explicit experimental_prefetch_to_device=True places data on the TPU."""
  strategy = get_tpu_strategy()
  replicas = strategy.num_replicas_in_sync
  dataset = dataset_ops.Dataset.range(
      replicas * 2, output_type=dtypes.float32).batch(replicas)

  opts = distribute_lib.InputOptions(experimental_prefetch_to_device=True)
  element = next(
      iter(strategy.experimental_distribute_dataset(dataset, options=opts)))
  spec = tf_device.DeviceSpec.from_string(element.values[0].device)
  self.assertEqual(spec.device_type, "TPU")
def test_prefetch_to_device_cpu(self):
  """experimental_prefetch_to_device=False keeps elements in host memory."""
  strategy = get_tpu_strategy()
  replicas = strategy.num_replicas_in_sync
  dataset = dataset_ops.Dataset.range(
      replicas * 2, output_type=dtypes.float32).batch(replicas)

  # Prefetch to device disabled: elements should be placed on the host CPU.
  opts = distribute_lib.InputOptions(experimental_prefetch_to_device=False)
  element = next(
      iter(strategy.experimental_distribute_dataset(dataset, options=opts)))
  spec = tf_device.DeviceSpec.from_string(element.values[0].device)
  self.assertEqual(spec.device_type, "CPU")
if __name__ == "__main__":
test.main()

View File

@ -26,11 +26,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_local_results"

View File

@ -26,11 +26,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_local_results"

View File

@ -25,11 +25,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_local_results"

View File

@ -26,11 +26,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_local_results"

View File

@ -26,11 +26,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_local_results"

View File

@ -26,11 +26,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_local_results"

View File

@ -30,11 +30,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_local_results"

View File

@ -0,0 +1,19 @@
path: "tensorflow.distribute.InputOptions"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputOptions\'>"
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.InputOptions\'>"
is_instance: "<type \'tuple\'>"
member {
name: "experimental_prefetch_to_device"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "count"
}
member_method {
name: "index"
}
}

View File

@ -30,11 +30,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_values_from_function"

View File

@ -29,11 +29,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_values_from_function"

View File

@ -30,11 +30,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_values_from_function"

View File

@ -30,11 +30,11 @@ tf_class {
}
member_method {
name: "experimental_distribute_dataset"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'dataset_fn\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "experimental_distribute_values_from_function"

View File

@ -16,6 +16,10 @@ tf_module {
name: "InputContext"
mtype: "<type \'type\'>"
}
member {
name: "InputOptions"
mtype: "<type \'type\'>"
}
member {
name: "InputReplicationMode"
mtype: "<class \'enum.EnumMeta\'>"