[retry]DistributedDataset creates elements with fixed spec to help avoid retracing
tf.function tracing depends on the inputs to the function. For a typical training loop: x, y = next(iter) train_fn(x,y) it may retrace when getting a partial/batches. This is problematic for multi client training since different client may retrace at different time. We assign collective instance_key when tracing a function, retracing results in different sets of instance keys. This change we overrides the PerReplica type spec, which is used to calculate function cache key. This tries to avoid retracing in common cases, but it doesn't guarantee that it won't happen. Note that after such change, the function also gets partial shape information. This is the reason we only do it for multi client strategies (MWMS), to avoid performance penalty to e.g. TPU. PiperOrigin-RevId: 338203534 Change-Id: Iae9d6c3c82113d623707e19142fbebe5597d7898
This commit is contained in:
parent
fbcdf129b9
commit
9f51b98f0b
@ -1018,22 +1018,24 @@ distribute_py_test(
|
||||
"multi_and_single_gpu",
|
||||
],
|
||||
deps = [
|
||||
":collective_all_reduce_strategy",
|
||||
":combinations",
|
||||
":input_lib",
|
||||
":mirrored_strategy",
|
||||
":multi_worker_test_base",
|
||||
":reduce_util",
|
||||
":distribute_lib",
|
||||
":strategy_combinations",
|
||||
":test_util",
|
||||
":tpu_strategy",
|
||||
":values",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:sparse_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:composite_tensor",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:string_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/data/ops:readers",
|
||||
"//tensorflow/python/eager:def_function",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python/ops/ragged:ragged_tensor",
|
||||
"//third_party/py/numpy",
|
||||
|
@ -559,7 +559,8 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False):
|
||||
flattened_data = []
|
||||
for per_worker_data in replicas:
|
||||
flattened_data.extend(per_worker_data)
|
||||
replicas = distribute_utils.regroup(flattened_data)
|
||||
replicas = _create_per_replica(
|
||||
flattened_data, strategy, get_next_as_optional=True)
|
||||
|
||||
# Run an all-reduce to see whether any worker has values.
|
||||
# TODO(b/131423105): we should be able to short-cut the all-reduce in some
|
||||
@ -659,7 +660,8 @@ class DistributedIteratorBase(DistributedIteratorInterface):
|
||||
# Make `replicas` a flat list of values across all replicas.
|
||||
replicas.extend(
|
||||
self._iterators[i].get_next_as_list_static_shapes(new_name))
|
||||
return distribute_utils.regroup(replicas)
|
||||
return _create_per_replica(
|
||||
replicas, self._strategy, get_next_as_optional=False)
|
||||
|
||||
out_of_range_replicas = []
|
||||
def out_of_range_fn(worker_index, device):
|
||||
@ -693,7 +695,8 @@ class DistributedIteratorBase(DistributedIteratorInterface):
|
||||
results.append(result)
|
||||
replicas = results
|
||||
|
||||
return distribute_utils.regroup(replicas)
|
||||
return _create_per_replica(replicas, self._strategy,
|
||||
self._enable_get_next_as_optional)
|
||||
|
||||
|
||||
class DistributedIteratorV1(DistributedIteratorBase):
|
||||
@ -893,11 +896,25 @@ class DistributedIterator(DistributedIteratorBase,
|
||||
|
||||
@property
|
||||
def element_spec(self):
|
||||
# When partial batch handling is enabled, always set the batch dimension to
|
||||
# None, otherwise we just follow element_spec of the underlying dataset
|
||||
# (whose batch dimension may also be None). This is because with partial
|
||||
# batching handling we could always produce empty batches.
|
||||
#
|
||||
# TODO(b/163362689): avoid this once we have more elegent way to handle
|
||||
# retracing and collectives.
|
||||
if (self._enable_get_next_as_optional and
|
||||
self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
|
||||
return nest.map_structure(
|
||||
_rebatch_as_dynamic, self._element_spec, expand_composites=False)
|
||||
return self._element_spec
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
return DistributedIteratorSpec(self._input_workers, self.element_spec,
|
||||
# Note that we use actual element_spec to create DistributedIteratorSpec,
|
||||
# to be consistent with the underlying iterators' specs.
|
||||
# TODO(b/163362689): remove the comment after the bug if fixed.
|
||||
return DistributedIteratorSpec(self._input_workers, self._element_spec,
|
||||
self._strategy,
|
||||
self._enable_get_next_as_optional)
|
||||
|
||||
@ -1097,7 +1114,7 @@ class DistributedDataset(_IterableInput):
|
||||
worker_iterators,
|
||||
self._strategy,
|
||||
enable_get_next_as_optional=self._enable_get_next_as_optional)
|
||||
iterator._element_spec = self.element_spec # pylint: disable=protected-access
|
||||
iterator._element_spec = self._element_spec # pylint: disable=protected-access
|
||||
|
||||
# When async eager is enabled, sometimes the iterator may not finish
|
||||
# initialization before passing to a multi device function, add a sync point
|
||||
@ -1110,6 +1127,17 @@ class DistributedDataset(_IterableInput):
|
||||
@property
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this dataset."""
|
||||
# When partial batch handling is enabled, always set the batch dimension to
|
||||
# None, otherwise we just follow element_spec of the underlying dataset
|
||||
# (whose batch dimension may also be None). This is because with partial
|
||||
# batching handling we could always produce empty batches.
|
||||
#
|
||||
# TODO(b/163362689): avoid this once we have more elegent way to handle
|
||||
# retracing and collectives.
|
||||
if (self._enable_get_next_as_optional and
|
||||
self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
|
||||
return nest.map_structure(
|
||||
_rebatch_as_dynamic, self._element_spec, expand_composites=False)
|
||||
return self._element_spec
|
||||
|
||||
|
||||
@ -1279,6 +1307,17 @@ class DistributedDatasetsFromFunction(_IterableInput):
|
||||
@property
|
||||
def element_spec(self):
|
||||
"""The type specification of an element of this dataset."""
|
||||
# When partial batch handling is enabled, always set the batch dimension to
|
||||
# None, otherwise we just follow element_spec of the underlying dataset
|
||||
# (whose batch dimension may also be None). This is because with partial
|
||||
# batching handling we could always produce empty batches.
|
||||
#
|
||||
# TODO(b/163362689): avoid this once we have more elegent way to handle
|
||||
# retracing and collectives.
|
||||
if (self._enable_get_next_as_optional and
|
||||
self._strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
|
||||
return nest.map_structure(
|
||||
_rebatch_as_dynamic, self._element_spec, expand_composites=False)
|
||||
return self._element_spec
|
||||
|
||||
|
||||
@ -1376,6 +1415,7 @@ class InputFunctionIterator(DistributedIteratorV1):
|
||||
|
||||
super(InputFunctionIterator, self).__init__(
|
||||
input_workers, iterators, strategy, enable_get_next_as_optional=False)
|
||||
self._enable_get_next_as_optional = False
|
||||
|
||||
|
||||
# TODO(anjalisridhar): This class will soon be removed and users should move
|
||||
@ -2065,13 +2105,14 @@ def _create_distributed_tensor_spec(strategy, tensor_spec):
|
||||
"""
|
||||
num_replicas = len(strategy.extended.worker_devices)
|
||||
|
||||
# If the number of devices used in the strategy is just 1 then we return
|
||||
# the tensor_spec as is.
|
||||
if num_replicas == 1:
|
||||
# For one device strategy that is not MultiWorkerMirroredStrategy, return the
|
||||
# tensor_spec as is, since we don't wrap the output with PerReplica in this
|
||||
# case.
|
||||
# TODO(b/166464552): remove after we always wrap for all strategies.
|
||||
if not _always_wrap(strategy):
|
||||
return tensor_spec
|
||||
|
||||
# If the number of devices is greater than 1 then we assume the input to
|
||||
# tf.function is a per replica type.
|
||||
# For other cases we assume the input to tf.function is a per replica type.
|
||||
def _get_value_per_replica(tensor_spec_per_input):
|
||||
value_specs = [tensor_spec_per_input for _ in range(num_replicas)]
|
||||
return values.PerReplicaSpec(*value_specs)
|
||||
@ -2109,3 +2150,70 @@ def _enable_get_next_as_optional(strategy, dataset):
|
||||
|
||||
return not _is_statically_shaped(
|
||||
dataset.element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access
|
||||
|
||||
|
||||
def _create_per_replica(value_list, strategy, get_next_as_optional):
|
||||
"""Creates a PerReplica.
|
||||
|
||||
For strategies other than OneDeviceStrategy, it creates a PerReplica whose
|
||||
type spec is set to the element spec of the dataset. This helps avoid
|
||||
retracing for partial batches. Retracing is problematic for multi client when
|
||||
different client retraces different time, since retracing changes the
|
||||
collective keys in the tf.function, and causes mismatches among clients.
|
||||
|
||||
For single client strategies, this simply calls distribute_utils.regroup().
|
||||
|
||||
Args:
|
||||
value_list: a list of values, one for each replica.
|
||||
strategy: the `tf.distribute.Strategy`.
|
||||
get_next_as_optional: whether last partial batch handling is enabled.
|
||||
|
||||
Returns:
|
||||
a structure of PerReplica.
|
||||
|
||||
"""
|
||||
# TODO(b/166464552): always wrap for all one device strategies as well.
|
||||
always_wrap = _always_wrap(strategy)
|
||||
per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
|
||||
|
||||
# When partial batch handling is enabled, always set the batch dimension to
|
||||
# None, otherwise we just follow element_spec of the underlying dataset
|
||||
# (whose batch dimension may also be None). This is because with partial
|
||||
# batching handling we could always produce empty batches.
|
||||
#
|
||||
# TODO(b/163362689): avoid this once we have more elegent way to handle
|
||||
# retracing and collectives.
|
||||
if (get_next_as_optional and strategy.extended._in_multi_worker_mode()): # pylint: disable=protected-access
|
||||
# Use expand_composites=False since we don't want to expand PerReplica,
|
||||
# which is a CompositeTensor.
|
||||
flat_per_replicas = nest.flatten(per_replicas, expand_composites=False)
|
||||
flat_spec = [type_spec.type_spec_from_value(v) for v in flat_per_replicas]
|
||||
for per_replica, spec in zip(flat_per_replicas, flat_spec):
|
||||
per_replica._type_spec_override = _rebatch_as_dynamic(spec) # pylint: disable=protected-access
|
||||
per_replicas = nest.pack_sequence_as(per_replicas, flat_per_replicas)
|
||||
|
||||
return per_replicas
|
||||
|
||||
|
||||
def _always_wrap(strategy):
|
||||
"""Returns whether to always wrap the values in a DistributedValues."""
|
||||
return strategy.extended._in_multi_worker_mode() or len( # pylint: disable=protected-access
|
||||
strategy.extended.worker_devices) > 1
|
||||
|
||||
|
||||
def _rebatch_as_dynamic(per_replica_spec):
|
||||
"""Rebatch the spec to have a dynamic batch dimension."""
|
||||
assert isinstance(per_replica_spec, values.PerReplicaSpec), per_replica_spec
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def _rebatch(spec):
|
||||
# Rebatch if possible.
|
||||
try:
|
||||
return spec._unbatch()._batch(None)
|
||||
except ValueError:
|
||||
pass
|
||||
return spec
|
||||
|
||||
return values.PerReplicaSpec(
|
||||
*nest.map_structure(_rebatch, per_replica_spec._value_specs))
|
||||
# pylint: enable=protected-access
|
||||
|
@ -557,7 +557,7 @@ class DistributedIteratorTest(DistributedIteratorTestBase,
|
||||
|
||||
iterator = iter(dist_dataset)
|
||||
for i, element in enumerate(iterator):
|
||||
self.assertEqual(i, element.numpy())
|
||||
self.assertAllEqual(distribution.experimental_local_results(element), [i])
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
|
@ -18,15 +18,18 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute import test_util
|
||||
from tensorflow.python.distribute import tpu_strategy
|
||||
from tensorflow.python.distribute import values
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -37,6 +40,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
@ -116,14 +120,17 @@ class DistributedIteratorTest(test.TestCase,
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
],
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
enable_get_next_as_optional=[True, False],
|
||||
drop_remainder=[True, False],
|
||||
tf_api_version=2,
|
||||
))
|
||||
def testDoesNotTriggerFunctionTracing(self, input_type, distribution,
|
||||
enable_get_next_as_optional):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("DistributedIterator CompositeTensor support is only "
|
||||
"present in TF 2.0 only.")
|
||||
|
||||
enable_get_next_as_optional,
|
||||
drop_remainder):
|
||||
trace_count = [0]
|
||||
|
||||
@def_function.function
|
||||
@ -135,7 +142,8 @@ class DistributedIteratorTest(test.TestCase,
|
||||
counter += 1
|
||||
return counter
|
||||
|
||||
dataset = dataset_ops.DatasetV2.range(10).batch(2)
|
||||
dataset = dataset_ops.DatasetV2.range(10).batch(
|
||||
2, drop_remainder=drop_remainder)
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = (
|
||||
enable_get_next_as_optional)
|
||||
@ -161,27 +169,79 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
|
||||
strategy_combinations.mirrored_strategy_with_one_cpu,
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
],
|
||||
input_type=["dataset", "dataset_fn"],
|
||||
tf_api_version=2,
|
||||
enable_get_next_as_optional=[True, False],
|
||||
drop_remainder=[True, False],
|
||||
))
|
||||
def testInputSignatureForPerReplicaValues(self, distribution, input_type):
|
||||
def dataset_fn(ctx):
|
||||
del ctx # unused
|
||||
return dataset_ops.DatasetV2.from_tensor_slices(
|
||||
np.ones([10, 12]).astype(np.float32)).batch(4)
|
||||
def testInputSignatureForPerReplicaValues(self, distribution,
|
||||
enable_get_next_as_optional,
|
||||
drop_remainder):
|
||||
distribution.extended.experimental_enable_get_next_as_optional = (
|
||||
enable_get_next_as_optional)
|
||||
ds = dataset_ops.DatasetV2.from_tensor_slices(
|
||||
np.ones([9, 12]).astype(np.float32)).batch(
|
||||
4, drop_remainder=drop_remainder)
|
||||
ds = distribution.experimental_distribute_dataset(ds)
|
||||
_check_type_spec_structure(iter(ds))
|
||||
element_spec = ds.element_spec
|
||||
iter_element_spec = iter(ds).element_spec
|
||||
nest.assert_same_structure(element_spec, iter_element_spec)
|
||||
self.assertAllEqual(
|
||||
nest.flatten(element_spec), nest.flatten(iter_element_spec))
|
||||
|
||||
if input_type == "dataset":
|
||||
ds = distribution.experimental_distribute_dataset(
|
||||
dataset_fn(distribute_lib.InputContext()))
|
||||
type_spec = ds.element_spec
|
||||
else:
|
||||
ds = distribution.distribute_datasets_from_function(dataset_fn)
|
||||
iterator = iter(ds)
|
||||
_check_type_spec_structure(iterator)
|
||||
type_spec = iterator.element_spec
|
||||
@def_function.function(input_signature=[element_spec])
|
||||
def process_inputs(inputs):
|
||||
distribution.run(lambda inputs: inputs, args=(inputs,))
|
||||
|
||||
@def_function.function(input_signature=[type_spec])
|
||||
for x in ds:
|
||||
process_inputs(x)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
],
|
||||
tf_api_version=2,
|
||||
enable_get_next_as_optional=[True, False],
|
||||
drop_remainder=[True, False],
|
||||
))
|
||||
def testFromFunctionInputSignatureForPerReplicaValues(
|
||||
self, distribution, enable_get_next_as_optional, drop_remainder):
|
||||
# Create files that produce partial/empty batches at different batch. Note
|
||||
# that some worker will get empty batches even when drop_remainder=True.
|
||||
fname1 = os.path.join(self.get_temp_dir(), "1.txt")
|
||||
_create_text_file(fname1, 5)
|
||||
fname2 = os.path.join(self.get_temp_dir(), "2.txt")
|
||||
_create_text_file(fname2, 9)
|
||||
|
||||
def dataset_fn(input_context):
|
||||
dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
|
||||
dataset = dataset.shard(input_context.num_input_pipelines,
|
||||
input_context.input_pipeline_id)
|
||||
return readers.TextLineDatasetV2(dataset).map(
|
||||
string_ops.string_to_number).batch(
|
||||
input_context.get_per_replica_batch_size(4),
|
||||
drop_remainder=drop_remainder)
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = (
|
||||
enable_get_next_as_optional)
|
||||
ds = distribution.experimental_distribute_datasets_from_function(dataset_fn)
|
||||
_check_type_spec_structure(iter(ds))
|
||||
element_spec = ds.element_spec
|
||||
iter_element_spec = iter(ds).element_spec
|
||||
nest.assert_same_structure(element_spec, iter_element_spec)
|
||||
self.assertAllEqual(
|
||||
nest.flatten(element_spec), nest.flatten(iter_element_spec))
|
||||
|
||||
@def_function.function(input_signature=[element_spec])
|
||||
def process_inputs(inputs):
|
||||
distribution.run(lambda inputs: inputs, args=(inputs,))
|
||||
|
||||
@ -247,6 +307,149 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertEqual(spec1, spec1.most_specific_compatible_type(spec2))
|
||||
self.assertEqual(spec1, spec2.most_specific_compatible_type(spec1))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
],
|
||||
tf_api_version=2,
|
||||
drop_remainder=[True, False],
|
||||
))
|
||||
def testFromDatasetDoesNotTriggerFunctionTracing(self, distribution,
|
||||
drop_remainder):
|
||||
self.trace_count = 0
|
||||
|
||||
@def_function.function
|
||||
def f(v):
|
||||
del v
|
||||
self.trace_count += 1
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = True
|
||||
# Total dataset size 5 allows us to have full batches, partial batches and
|
||||
# empty batches.
|
||||
dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3))).batch(
|
||||
4, drop_remainder=drop_remainder)
|
||||
dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
for v in iter(dataset):
|
||||
f(v)
|
||||
self.assertEqual(self.trace_count, 1)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
],
|
||||
tf_api_version=2,
|
||||
drop_remainder=[True, False],
|
||||
))
|
||||
def testFromDatasetFileShardingDoesNotTriggerFunctionTracing(
|
||||
self, distribution, drop_remainder):
|
||||
# Create files that produce partial/empty batches at different batch.
|
||||
fname1 = os.path.join(self.get_temp_dir(), "1.txt")
|
||||
_create_text_file(fname1, 5)
|
||||
fname2 = os.path.join(self.get_temp_dir(), "2.txt")
|
||||
_create_text_file(fname2, 9)
|
||||
|
||||
self.trace_count = 0
|
||||
|
||||
@def_function.function
|
||||
def f(v):
|
||||
del v
|
||||
self.trace_count += 1
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = True
|
||||
dataset = readers.TextLineDatasetV2([fname1, fname2]).batch(
|
||||
4, drop_remainder=drop_remainder)
|
||||
dataset = distribution.experimental_distribute_dataset(dataset)
|
||||
for v in iter(dataset):
|
||||
f(v)
|
||||
self.assertEqual(self.trace_count, 1)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
],
|
||||
tf_api_version=2,
|
||||
drop_remainder=[True, False],
|
||||
))
|
||||
def testFromFunctionDoesNotTriggerFunctionTracing(self, distribution,
|
||||
drop_remainder):
|
||||
|
||||
def dataset_fn(input_context):
|
||||
# Total dataset size 5 allows us to have full batches, partial batches and
|
||||
# empty batches.
|
||||
dataset = dataset_ops.DatasetV2.from_tensor_slices(np.ones((5, 3)))
|
||||
dataset = dataset.batch(
|
||||
input_context.get_per_replica_batch_size(4),
|
||||
drop_remainder=drop_remainder)
|
||||
return dataset.shard(input_context.num_input_pipelines,
|
||||
input_context.input_pipeline_id)
|
||||
|
||||
self.trace_count = 0
|
||||
|
||||
@def_function.function
|
||||
def f(v):
|
||||
del v
|
||||
self.trace_count += 1
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = True
|
||||
dataset = distribution.experimental_distribute_datasets_from_function(
|
||||
dataset_fn)
|
||||
for v in iter(dataset):
|
||||
f(v)
|
||||
self.assertEqual(self.trace_count, 1)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
distribution=[
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
],
|
||||
tf_api_version=2,
|
||||
drop_remainder=[True, False],
|
||||
))
|
||||
def testFromFunctionFileShardingDoesNotTriggerFunctionTracing(
|
||||
self, distribution, drop_remainder):
|
||||
# Create files that produce partial/empty batches at different batch.
|
||||
fname1 = os.path.join(self.get_temp_dir(), "1.txt")
|
||||
_create_text_file(fname1, 5)
|
||||
fname2 = os.path.join(self.get_temp_dir(), "2.txt")
|
||||
_create_text_file(fname2, 9)
|
||||
|
||||
def dataset_fn(input_context):
|
||||
dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2])
|
||||
dataset = dataset.shard(input_context.num_input_pipelines,
|
||||
input_context.input_pipeline_id)
|
||||
return readers.TextLineDatasetV2(dataset).batch(
|
||||
input_context.get_per_replica_batch_size(4),
|
||||
drop_remainder=drop_remainder)
|
||||
|
||||
self.trace_count = 0
|
||||
|
||||
@def_function.function
|
||||
def f(v):
|
||||
del v
|
||||
self.trace_count += 1
|
||||
|
||||
distribution.extended.experimental_enable_get_next_as_optional = True
|
||||
dataset = distribution.experimental_distribute_datasets_from_function(
|
||||
dataset_fn)
|
||||
for v in iter(dataset):
|
||||
f(v)
|
||||
self.assertEqual(self.trace_count, 1)
|
||||
|
||||
|
||||
class RaggedTensorDistributedIteratorTest(test.TestCase,
|
||||
parameterized.TestCase):
|
||||
@ -254,14 +457,14 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
tf_api_version=2,
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
],
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testTypeSpec(self, distribution, enable_get_next_as_optional):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("DistributedIterator has CompositeTensor support in "
|
||||
"TF 2.0 only.")
|
||||
ctx = distribute_lib.InputContext()
|
||||
batch_size = ctx.get_per_replica_batch_size(8)
|
||||
# Use 20 which isn't divisible by 8 to test partial batch behavior.
|
||||
@ -313,16 +516,16 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
tf_api_version=2,
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.central_storage_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
],
|
||||
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testTypeSpecRoundTrip(self, distribution, enable_get_next_as_optional):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("DistributedIterator CompositeTensor support is only "
|
||||
"present in TF 2.0 only.")
|
||||
|
||||
ctx = distribute_lib.InputContext()
|
||||
batch_size = ctx.get_per_replica_batch_size(8)
|
||||
# Use 20 which isn't divisible by 8 to test partial batch behavior.
|
||||
@ -366,17 +569,17 @@ class RaggedTensorDistributedIteratorTest(test.TestCase,
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
mode=["eager"],
|
||||
tf_api_version=2,
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_cpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x1_gpu,
|
||||
strategy_combinations.multi_worker_mirrored_2x2_gpu,
|
||||
],
|
||||
enable_get_next_as_optional=[True, False]))
|
||||
def testDoesNotTriggerFunctionTracing(self, distribution,
|
||||
enable_get_next_as_optional):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("DistributedIterator CompositeTensor support is only "
|
||||
"present in TF 2.0 only.")
|
||||
|
||||
trace_count = [0]
|
||||
|
||||
@def_function.function
|
||||
@ -432,5 +635,11 @@ def _check_type_spec_structure(x):
|
||||
nest.assert_same_structure(x, x._type_spec, expand_composites=True)
|
||||
|
||||
|
||||
def _create_text_file(fname, num_lines):
|
||||
with open(fname, "w") as f:
|
||||
for i in range(num_lines):
|
||||
f.write("%d\n" % i)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
test_util.main()
|
||||
|
@ -362,8 +362,23 @@ class DistributedDelegate(DistributedValues):
|
||||
class PerReplica(DistributedValues, composite_tensor.CompositeTensor):
|
||||
"""Holds a map from replica to unsynchronized values."""
|
||||
|
||||
def __init__(self, values, type_spec_override=None):
|
||||
super(PerReplica, self).__init__(values)
|
||||
# Allow setting a type spec that can be different from the underlying
|
||||
# values. This allows us avoid retracing for PerReplica from full, partial
|
||||
# and empty batches. In a multi client setup, we need to avoid such
|
||||
# retracing otherwise the collectives may mismatch since we assign new
|
||||
# collective keys when retracing the function.
|
||||
#
|
||||
# TODO(b/166169298): remove after CrossDeviceOps is tracing safe.
|
||||
self._type_spec_override = type_spec_override
|
||||
|
||||
@property
|
||||
def _type_spec(self):
|
||||
if self._type_spec_override is not None:
|
||||
# Return a deep copy in case the caller changes it, since _type_spec()
|
||||
# normally returns a temporary object.
|
||||
return copy.deepcopy(self._type_spec_override)
|
||||
return PerReplicaSpec(
|
||||
*(type_spec.type_spec_from_value(v) for v in self._values))
|
||||
|
||||
|
@ -341,7 +341,7 @@ distribute_py_test(
|
||||
full_precision = True,
|
||||
main = "distribute_strategy_test.py",
|
||||
python_version = "PY3",
|
||||
shard_count = 10,
|
||||
shard_count = 20,
|
||||
tags = [
|
||||
"multi_and_single_gpu",
|
||||
"no_rocm", # times out on ROCm
|
||||
|
Loading…
Reference in New Issue
Block a user