Add experimental_make_distributed_values_from_function method to distribution strategy.
PiperOrigin-RevId: 298445422 Change-Id: I907cf7808e5bedaf45adccf3b6355ccf219e4116
This commit is contained in:
parent
5c1f1be66d
commit
3f35d6d8b0
@ -415,6 +415,70 @@ class InputContext(object):
|
|||||||
self.input_pipeline_id, self.num_input_pipelines)
|
self.input_pipeline_id, self.num_input_pipelines)
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("distribute.experimental.ValueContext", v1=[])
|
||||||
|
class ValueContext(object):
|
||||||
|
"""A class wrapping information needed by a distribute function.
|
||||||
|
|
||||||
|
This is a context class that is passed to the `value_fn` in
|
||||||
|
`strategy.experimental_distribute_values_from_function` and contains
|
||||||
|
information about the compute replicas. The `num_replicas_in_sync` and
|
||||||
|
`replica_id` can be used to customize the value on each replica.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
1. Directly constructed.
|
||||||
|
|
||||||
|
>>> def value_fn(context):
|
||||||
|
... return context.replica_id_in_sync_group/context.num_replicas_in_sync
|
||||||
|
>>> context = tf.distribute.experimental.ValueContext(
|
||||||
|
... replica_id_in_sync_group=2, num_replicas_in_sync=4)
|
||||||
|
>>> per_replica_value = value_fn(context)
|
||||||
|
>>> per_replica_value
|
||||||
|
0.5
|
||||||
|
|
||||||
|
2. Passed in by `experimental_distribute_values_from_function`.
|
||||||
|
|
||||||
|
>>> strategy = tf.distribute.MirroredStrategy()
|
||||||
|
>>> def value_fn(value_context):
|
||||||
|
... return value_context.num_replicas_in_sync
|
||||||
|
>>> distributed_values = (
|
||||||
|
... strategy.experimental_distribute_values_from_function(
|
||||||
|
... value_fn))
|
||||||
|
>>> local_result = strategy.experimental_local_results(distributed_values)
|
||||||
|
>>> local_result
|
||||||
|
(1,)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
replica_id_in_sync_group=0,
|
||||||
|
num_replicas_in_sync=1):
|
||||||
|
"""Initializes an ValueContext object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
replica_id_in_sync_group: the current replica_id, should be an int in
|
||||||
|
[0,`num_replicas_in_sync`).
|
||||||
|
num_replicas_in_sync: the number of replicas that are in sync.
|
||||||
|
"""
|
||||||
|
self._replica_id_in_sync_group = replica_id_in_sync_group
|
||||||
|
self._num_replicas_in_sync = num_replicas_in_sync
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_replicas_in_sync(self):
|
||||||
|
"""Returns the number of compute replicas in sync."""
|
||||||
|
return self._num_replicas_in_sync
|
||||||
|
|
||||||
|
@property
|
||||||
|
def replica_id_in_sync_group(self):
|
||||||
|
"""Returns the replica ID."""
|
||||||
|
return self._replica_id_in_sync_group
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return (("tf.distribute.ValueContext(replica id {}, "
|
||||||
|
" total replicas in sync: ""{})")
|
||||||
|
.format(self.replica_id_in_sync_group, self.num_replicas_in_sync))
|
||||||
|
|
||||||
|
|
||||||
@tf_export("distribute.RunOptions")
|
@tf_export("distribute.RunOptions")
|
||||||
class RunOptions(
|
class RunOptions(
|
||||||
collections.namedtuple("RunOptions", [
|
collections.namedtuple("RunOptions", [
|
||||||
@ -1212,6 +1276,81 @@ class Strategy(StrategyBase):
|
|||||||
"""
|
"""
|
||||||
return self._extended._experimental_replicate_to_logical_devices(tensor) # pylint: disable=protected-access
|
return self._extended._experimental_replicate_to_logical_devices(tensor) # pylint: disable=protected-access
|
||||||
|
|
||||||
|
def experimental_distribute_values_from_function(self, value_fn):
|
||||||
|
"""Generates `tf.distribute.DistributedValues` from `value_fn`.
|
||||||
|
|
||||||
|
This function is to generate `tf.distribute.DistributedValues` to pass
|
||||||
|
into `experimental_run_v2`, `reduce`, or other methods that take
|
||||||
|
distributed values when not using datasets.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value_fn: The function to run to generate values. It is called for
|
||||||
|
each replica with `tf.distribute.ValueContext` as the sole argument. It
|
||||||
|
must return a Tensor or a type that can be converted to a Tensor.
|
||||||
|
Returns:
|
||||||
|
A `tf.distribute.DistributedValues` containing a value for each replica.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
|
||||||
|
1. Return constant value per replica:
|
||||||
|
|
||||||
|
>>> strategy = tf.distribute.MirroredStrategy()
|
||||||
|
>>> def value_fn(ctx):
|
||||||
|
... return tf.constant(1.)
|
||||||
|
>>> distributed_values = (
|
||||||
|
... strategy.experimental_distribute_values_from_function(
|
||||||
|
... value_fn))
|
||||||
|
>>> local_result = strategy.experimental_local_results(distributed_values)
|
||||||
|
>>> local_result
|
||||||
|
(<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,)
|
||||||
|
|
||||||
|
2. Distribute values in array based on replica_id:
|
||||||
|
|
||||||
|
>>> strategy = tf.distribute.MirroredStrategy()
|
||||||
|
>>> array_value = np.array([3., 2., 1.])
|
||||||
|
>>> def value_fn(ctx):
|
||||||
|
... return array_value[ctx.replica_id_in_sync_group]
|
||||||
|
>>> distributed_values = (
|
||||||
|
... strategy.experimental_distribute_values_from_function(
|
||||||
|
... value_fn))
|
||||||
|
>>> local_result = strategy.experimental_local_results(distributed_values)
|
||||||
|
>>> local_result
|
||||||
|
(3.0,)
|
||||||
|
|
||||||
|
3. Specify values using num_replicas_in_sync:
|
||||||
|
|
||||||
|
>>> strategy = tf.distribute.MirroredStrategy()
|
||||||
|
>>> def value_fn(ctx):
|
||||||
|
... return ctx.num_replicas_in_sync
|
||||||
|
>>> distributed_values = (
|
||||||
|
... strategy.experimental_distribute_values_from_function(
|
||||||
|
... value_fn))
|
||||||
|
>>> local_result = strategy.experimental_local_results(distributed_values)
|
||||||
|
>>> local_result
|
||||||
|
(1,)
|
||||||
|
|
||||||
|
4. Place values on devices and distribute:
|
||||||
|
|
||||||
|
```
|
||||||
|
strategy = tf.distribute.TPUStrategy()
|
||||||
|
worker_devices = strategy.extended.worker_devices
|
||||||
|
multiple_values = []
|
||||||
|
for i in range(strategy.num_replicas_in_sync):
|
||||||
|
with tf.device(worker_devices[i]):
|
||||||
|
multiple_values.append(tf.constant(1.0))
|
||||||
|
|
||||||
|
def value_fn(ctx):
|
||||||
|
return multiple_values[ctx.replica_id]
|
||||||
|
|
||||||
|
distributed_values = strategy.
|
||||||
|
experimental_distribute_values_from_function(
|
||||||
|
value_fn)
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access
|
||||||
|
value_fn)
|
||||||
|
|
||||||
|
|
||||||
# TF v1.x version has additional deprecated APIs
|
# TF v1.x version has additional deprecated APIs
|
||||||
@tf_export(v1=["distribute.Strategy"])
|
@tf_export(v1=["distribute.Strategy"])
|
||||||
@ -1717,6 +1856,9 @@ class StrategyExtendedV2(object):
|
|||||||
def _experimental_distribute_datasets_from_function(self, dataset_fn):
|
def _experimental_distribute_datasets_from_function(self, dataset_fn):
|
||||||
raise NotImplementedError("must be implemented in descendants")
|
raise NotImplementedError("must be implemented in descendants")
|
||||||
|
|
||||||
|
def _experimental_distribute_values_from_function(self, value_fn):
|
||||||
|
raise NotImplementedError("must be implemented in descendants")
|
||||||
|
|
||||||
def _reduce(self, reduce_op, value):
|
def _reduce(self, reduce_op, value):
|
||||||
# Default implementation until we have an implementation for each strategy.
|
# Default implementation until we have an implementation for each strategy.
|
||||||
return self._local_results(
|
return self._local_results(
|
||||||
@ -2436,6 +2578,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
|
|||||||
def _experimental_distribute_datasets_from_function(self, dataset_fn):
|
def _experimental_distribute_datasets_from_function(self, dataset_fn):
|
||||||
return dataset_fn(InputContext())
|
return dataset_fn(InputContext())
|
||||||
|
|
||||||
|
def _experimental_distribute_values_from_function(self, value_fn):
|
||||||
|
return value_fn(ValueContext())
|
||||||
|
|
||||||
def _make_dataset_iterator(self, dataset):
|
def _make_dataset_iterator(self, dataset):
|
||||||
return _DefaultDistributionExtended.DefaultInputIterator(dataset)
|
return _DefaultDistributionExtended.DefaultInputIterator(dataset)
|
||||||
|
|
||||||
|
@ -661,6 +661,14 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
input_contexts,
|
input_contexts,
|
||||||
self._container_strategy())
|
self._container_strategy())
|
||||||
|
|
||||||
|
def _experimental_distribute_values_from_function(self, value_fn):
|
||||||
|
per_replica_values = []
|
||||||
|
for replica_id in range(self._num_replicas_in_sync):
|
||||||
|
per_replica_values.append(value_fn(
|
||||||
|
distribute_lib.ValueContext(replica_id,
|
||||||
|
self._num_replicas_in_sync)))
|
||||||
|
return values.regroup(per_replica_values, always_wrap=True)
|
||||||
|
|
||||||
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
|
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
|
||||||
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
|
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
|
||||||
initial_loop_values=None):
|
initial_loop_values=None):
|
||||||
|
@ -307,6 +307,12 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
[distribute_lib.InputContext()],
|
[distribute_lib.InputContext()],
|
||||||
self._container_strategy())
|
self._container_strategy())
|
||||||
|
|
||||||
|
def _experimental_distribute_values_from_function(self, value_fn):
|
||||||
|
# TODO(b/137795644): This should return a PerReplica value but other
|
||||||
|
# methods like experimental_run_v2 in OneDeviceStrategy need to be modified
|
||||||
|
# to do the same.
|
||||||
|
return value_fn(distribute_lib.ValueContext())
|
||||||
|
|
||||||
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
|
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
|
||||||
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
|
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
|
||||||
initial_loop_values=None):
|
initial_loop_values=None):
|
||||||
|
@ -370,6 +370,12 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
[input_context],
|
[input_context],
|
||||||
self._container_strategy())
|
self._container_strategy())
|
||||||
|
|
||||||
|
def _experimental_distribute_values_from_function(self, value_fn):
|
||||||
|
# TODO(b/137795644): Implement this method for ParameterServerStrategy if
|
||||||
|
# needed.
|
||||||
|
raise NotImplementedError("_experimental_distribute_values_from_function "
|
||||||
|
"not yet implemented in ParameterServerStrategy.")
|
||||||
|
|
||||||
def _broadcast_to(self, tensor, destinations):
|
def _broadcast_to(self, tensor, destinations):
|
||||||
# This is both a fast path for Python constants, and a way to delay
|
# This is both a fast path for Python constants, and a way to delay
|
||||||
# converting Python values to a tensor until we know what type it
|
# converting Python values to a tensor until we know what type it
|
||||||
|
@ -255,12 +255,19 @@ strategies_minus_tpu = [
|
|||||||
mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
|
mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
|
||||||
]
|
]
|
||||||
|
|
||||||
|
strategies_minus_default_and_tpu = [
|
||||||
|
one_device_strategy, one_device_strategy_gpu,
|
||||||
|
mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
|
||||||
|
]
|
||||||
|
|
||||||
tpu_strategies = [
|
tpu_strategies = [
|
||||||
tpu_strategy, # steps_per_run=2
|
tpu_strategy, # steps_per_run=2
|
||||||
tpu_strategy_one_step,
|
tpu_strategy_one_step,
|
||||||
cloud_tpu_strategy,
|
cloud_tpu_strategy,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
all_strategies_minus_default = strategies_minus_default_and_tpu + tpu_strategies
|
||||||
|
|
||||||
all_strategies = strategies_minus_tpu + tpu_strategies
|
all_strategies = strategies_minus_tpu + tpu_strategies
|
||||||
|
|
||||||
multidevice_strategies = [
|
multidevice_strategies = [
|
||||||
|
@ -431,6 +431,14 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
input_contexts,
|
input_contexts,
|
||||||
self._container_strategy())
|
self._container_strategy())
|
||||||
|
|
||||||
|
def _experimental_distribute_values_from_function(self, value_fn):
|
||||||
|
per_replica_values = []
|
||||||
|
for replica_id in range(self._num_replicas_in_sync):
|
||||||
|
per_replica_values.append(
|
||||||
|
value_fn(distribute_lib.ValueContext(replica_id,
|
||||||
|
self._num_replicas_in_sync)))
|
||||||
|
return values.regroup(per_replica_values, always_wrap=True)
|
||||||
|
|
||||||
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
|
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
|
||||||
# TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
|
# TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
|
||||||
# a mechanism to infer the outputs of `fn`. Pending b/110550782.
|
# a mechanism to infer the outputs of `fn`. Pending b/110550782.
|
||||||
|
@ -1071,8 +1071,17 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
|
|||||||
_tensor_conversion_sync_on_read)
|
_tensor_conversion_sync_on_read)
|
||||||
|
|
||||||
|
|
||||||
def regroup(values, wrap_class=PerReplica):
|
def regroup(values, wrap_class=PerReplica, always_wrap=False):
|
||||||
"""Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
|
"""Makes a nest per-replica into a nest of PerReplica/Mirrored values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
values: Values to regroup
|
||||||
|
wrap_class: Class that `values` be wrapped in.
|
||||||
|
always_wrap: Always wrap the `values` in `wrap_class` even if the values
|
||||||
|
are the same except for DistributeVariable.
|
||||||
|
Returns:
|
||||||
|
Wrapped `values`.
|
||||||
|
"""
|
||||||
v0 = values[0]
|
v0 = values[0]
|
||||||
|
|
||||||
if isinstance(v0, list):
|
if isinstance(v0, list):
|
||||||
@ -1124,14 +1133,16 @@ def regroup(values, wrap_class=PerReplica):
|
|||||||
# devices), we want to return it. We check DistributedVariable
|
# devices), we want to return it. We check DistributedVariable
|
||||||
# specifically since it can look like it has a
|
# specifically since it can look like it has a
|
||||||
# _distributed_container member since its members do.
|
# _distributed_container member since its members do.
|
||||||
|
if same_id and isinstance(v0, DistributedVariable):
|
||||||
|
return v0
|
||||||
# * If v0 is a member of a distributed variable, in which case
|
# * If v0 is a member of a distributed variable, in which case
|
||||||
# hasattr(v0, "_distributed_container") is true, we want to
|
# hasattr(v0, "_distributed_container") is true, we want to
|
||||||
# return the DistributedVariable that contains it using the
|
# return the DistributedVariable that contains it using the
|
||||||
# _distributed_container logic below. This case can trigger
|
# _distributed_container logic below. This case can trigger
|
||||||
# same_id when there is only one device.
|
# same_id when there is only one device.
|
||||||
# * In any other situation, same_id means we return v0.
|
# * In any other situation, same_id means we return v0 unless `always_wrap` is
|
||||||
if same_id and (isinstance(v0, DistributedVariable) or
|
# true.
|
||||||
not hasattr(v0, "_distributed_container")):
|
if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
|
||||||
return v0
|
return v0
|
||||||
|
|
||||||
# Detect the case where each device has a parallel component of the
|
# Detect the case where each device has a parallel component of the
|
||||||
|
@ -23,7 +23,10 @@ import itertools
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.distribute import combinations
|
from tensorflow.python.distribute import combinations
|
||||||
from tensorflow.python.distribute import distribute_lib
|
from tensorflow.python.distribute import distribute_lib
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
@ -39,12 +42,14 @@ from tensorflow.python.framework import constant_op
|
|||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import indexed_slices
|
from tensorflow.python.framework import indexed_slices
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
|
from tensorflow.python.ops import sparse_ops
|
||||||
from tensorflow.python.ops import variable_scope
|
from tensorflow.python.ops import variable_scope
|
||||||
from tensorflow.python.ops import variables as variables_lib
|
from tensorflow.python.ops import variables as variables_lib
|
||||||
from tensorflow.python.saved_model.model_utils import mode_keys
|
from tensorflow.python.saved_model.model_utils import mode_keys
|
||||||
@ -54,7 +59,7 @@ from tensorflow.python.training.tracking import util as trackable_utils
|
|||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
class DistributedValuesTest(test.TestCase):
|
class DistributedValuesTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def testGetEager(self):
|
def testGetEager(self):
|
||||||
one = constant_op.constant(1)
|
one = constant_op.constant(1)
|
||||||
@ -73,6 +78,202 @@ class DistributedValuesTest(test.TestCase):
|
|||||||
with distribute_lib.ReplicaContext(None, 1):
|
with distribute_lib.ReplicaContext(None, 1):
|
||||||
self.assertEqual(two, v._get())
|
self.assertEqual(two, v._get())
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=strategy_combinations.all_strategies_minus_default,
|
||||||
|
mode=["eager"]
|
||||||
|
))
|
||||||
|
def testMakeDistributedValueFromTensor(self, distribution):
|
||||||
|
if not tf2.enabled():
|
||||||
|
self.skipTest("Only V2 is supported.")
|
||||||
|
single_value = constant_op.constant(1)
|
||||||
|
def value_fn(ctx):
|
||||||
|
del ctx
|
||||||
|
return single_value
|
||||||
|
|
||||||
|
distributed_values = (
|
||||||
|
distribution.experimental_distribute_values_from_function(value_fn))
|
||||||
|
self.assertAllEqual(
|
||||||
|
distribution.experimental_local_results(distributed_values),
|
||||||
|
constant_op.constant(1., shape=(distribution.num_replicas_in_sync)))
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=strategy_combinations.all_strategies_minus_default,
|
||||||
|
mode=["eager"]
|
||||||
|
))
|
||||||
|
def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
|
||||||
|
if not tf2.enabled():
|
||||||
|
self.skipTest("Only V2 is supported.")
|
||||||
|
array_value = np.array([1., 2., 3.])
|
||||||
|
def value_fn(ctx):
|
||||||
|
del ctx
|
||||||
|
return array_value
|
||||||
|
|
||||||
|
distributed_values = (
|
||||||
|
distribution.experimental_distribute_values_from_function(value_fn))
|
||||||
|
local_results = distribution.experimental_local_results(distributed_values)
|
||||||
|
self.assertLen(local_results, distribution.num_replicas_in_sync)
|
||||||
|
for result in local_results:
|
||||||
|
self.assertAllEqual(result, [1., 2., 3.])
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=strategy_combinations.all_strategies_minus_default,
|
||||||
|
mode=["eager"]
|
||||||
|
))
|
||||||
|
def testMakeDistributedValueTupleConstant(self, distribution):
|
||||||
|
if not tf2.enabled():
|
||||||
|
self.skipTest("Only V2 is supported.")
|
||||||
|
tuple_value = (1., 2., 3.)
|
||||||
|
def value_fn(ctx):
|
||||||
|
del ctx
|
||||||
|
return tuple_value
|
||||||
|
distributed_values = (
|
||||||
|
distribution.experimental_distribute_values_from_function(value_fn))
|
||||||
|
local_results = distribution.experimental_local_results(distributed_values)
|
||||||
|
for result in local_results:
|
||||||
|
self.assertAllEqual(result, (1., 2., 3.))
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=strategy_combinations.all_strategies_minus_default,
|
||||||
|
mode=["eager"]
|
||||||
|
))
|
||||||
|
def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
|
||||||
|
if not tf2.enabled():
|
||||||
|
self.skipTest("Only V2 is supported.")
|
||||||
|
tuple_value = (1., 2., 3.)
|
||||||
|
def value_fn(ctx):
|
||||||
|
per_replica = []
|
||||||
|
for val in tuple_value:
|
||||||
|
per_replica.append(val * ctx.replica_id_in_sync_group)
|
||||||
|
return per_replica
|
||||||
|
distributed_values = (
|
||||||
|
distribution.experimental_distribute_values_from_function(value_fn))
|
||||||
|
for i in range(distribution.num_replicas_in_sync):
|
||||||
|
self.assertAllEqual(
|
||||||
|
values.select_replica(i, distributed_values),
|
||||||
|
(1. * i, 2. * i, 3. * i))
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=strategy_combinations.all_strategies_minus_default,
|
||||||
|
mode=["eager"]
|
||||||
|
))
|
||||||
|
def testMakeDistributedValueSpareTensor(self, distribution):
|
||||||
|
if not tf2.enabled():
|
||||||
|
self.skipTest("Only V2 is supported.")
|
||||||
|
def value_fn(ctx):
|
||||||
|
del ctx
|
||||||
|
return sparse_tensor.SparseTensor(
|
||||||
|
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
|
||||||
|
|
||||||
|
distributed_values = (
|
||||||
|
distribution.experimental_distribute_values_from_function(value_fn))
|
||||||
|
local_results = distribution.experimental_local_results(distributed_values)
|
||||||
|
for i in range(distribution.num_replicas_in_sync):
|
||||||
|
self.assertAllEqual(
|
||||||
|
sparse_ops.sparse_tensor_to_dense(local_results[i]),
|
||||||
|
[[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=strategy_combinations.all_strategies_minus_default,
|
||||||
|
mode=["eager"]
|
||||||
|
))
|
||||||
|
def testMakeDistributedValueExtractFromArray(self, distribution):
|
||||||
|
if not tf2.enabled():
|
||||||
|
self.skipTest("Only V2 is supported.")
|
||||||
|
multiple_values = range(distribution.num_replicas_in_sync)
|
||||||
|
def value_fn(ctx):
|
||||||
|
return multiple_values[ctx.replica_id_in_sync_group]
|
||||||
|
distributed_values = (
|
||||||
|
distribution.experimental_distribute_values_from_function(value_fn))
|
||||||
|
local_results = distribution.experimental_local_results(distributed_values)
|
||||||
|
for i in range(distribution.num_replicas_in_sync):
|
||||||
|
self.assertAllEqual(local_results[i], i)
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=strategy_combinations.all_strategies_minus_default,
|
||||||
|
mode=["eager"]
|
||||||
|
))
|
||||||
|
def testMakeDistributedValueAndRun(self, distribution):
|
||||||
|
if not tf2.enabled():
|
||||||
|
self.skipTest("Only V2 is supported.")
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def run():
|
||||||
|
multiple_values = range(distribution.num_replicas_in_sync)
|
||||||
|
def value_fn(ctx):
|
||||||
|
return multiple_values[ctx.replica_id_in_sync_group]
|
||||||
|
distributed_values = (
|
||||||
|
distribution.experimental_distribute_values_from_function(value_fn))
|
||||||
|
|
||||||
|
def computation(x):
|
||||||
|
return math_ops.square(x)
|
||||||
|
|
||||||
|
outputs = distribution.experimental_local_results(
|
||||||
|
distribution.experimental_run_v2(computation,
|
||||||
|
args=(distributed_values,)))
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
local_results = run()
|
||||||
|
|
||||||
|
for i in range(distribution.num_replicas_in_sync):
|
||||||
|
self.assertAllEqual(local_results[i], i**2)
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=[
|
||||||
|
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
strategy_combinations.tpu_strategy,
|
||||||
|
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||||
|
],
|
||||||
|
mode=["eager"]
|
||||||
|
))
|
||||||
|
def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
|
||||||
|
if not tf2.enabled():
|
||||||
|
self.skipTest("Only V2 is supported.")
|
||||||
|
multiple_values = []
|
||||||
|
for i in range(distribution.num_replicas_in_sync):
|
||||||
|
multiple_values.append(constant_op.constant(1.0))
|
||||||
|
|
||||||
|
def value_fn(ctx):
|
||||||
|
return multiple_values[ctx.replica_id_in_sync_group]
|
||||||
|
distributed_values = (
|
||||||
|
distribution.experimental_distribute_values_from_function(value_fn))
|
||||||
|
for i in range(distribution.num_replicas_in_sync):
|
||||||
|
self.assertAllEqual(distributed_values._values[i].device,
|
||||||
|
"/job:localhost/replica:0/task:0/device:CPU:0")
|
||||||
|
|
||||||
|
@combinations.generate(
|
||||||
|
combinations.combine(
|
||||||
|
distribution=[
|
||||||
|
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||||
|
strategy_combinations.tpu_strategy,
|
||||||
|
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||||
|
],
|
||||||
|
mode=["eager"]
|
||||||
|
))
|
||||||
|
def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
|
||||||
|
if not tf2.enabled():
|
||||||
|
self.skipTest("Only V2 is supported.")
|
||||||
|
worker_devices = distribution.extended.worker_devices
|
||||||
|
multiple_values = []
|
||||||
|
for i in range(distribution.num_replicas_in_sync):
|
||||||
|
with ops.device(worker_devices[i]):
|
||||||
|
multiple_values.append(array_ops.identity(1.0))
|
||||||
|
|
||||||
|
def value_fn(ctx):
|
||||||
|
return multiple_values[ctx.replica_id_in_sync_group]
|
||||||
|
distributed_values = (
|
||||||
|
distribution.experimental_distribute_values_from_function(value_fn))
|
||||||
|
for i in range(distribution.num_replicas_in_sync):
|
||||||
|
self.assertAllEqual(distributed_values._values[i].device,
|
||||||
|
worker_devices[i])
|
||||||
|
|
||||||
|
|
||||||
class DistributedDelegateTest(test.TestCase):
|
class DistributedDelegateTest(test.TestCase):
|
||||||
|
|
||||||
|
@ -36,6 +36,10 @@ tf_class {
|
|||||||
name: "experimental_distribute_datasets_from_function"
|
name: "experimental_distribute_datasets_from_function"
|
||||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_distribute_values_from_function"
|
||||||
|
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_local_results"
|
name: "experimental_local_results"
|
||||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -36,6 +36,10 @@ tf_class {
|
|||||||
name: "experimental_distribute_datasets_from_function"
|
name: "experimental_distribute_datasets_from_function"
|
||||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_distribute_values_from_function"
|
||||||
|
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_local_results"
|
name: "experimental_local_results"
|
||||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -35,6 +35,10 @@ tf_class {
|
|||||||
name: "experimental_distribute_datasets_from_function"
|
name: "experimental_distribute_datasets_from_function"
|
||||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_distribute_values_from_function"
|
||||||
|
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_local_results"
|
name: "experimental_local_results"
|
||||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -36,6 +36,10 @@ tf_class {
|
|||||||
name: "experimental_distribute_datasets_from_function"
|
name: "experimental_distribute_datasets_from_function"
|
||||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_distribute_values_from_function"
|
||||||
|
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_local_results"
|
name: "experimental_local_results"
|
||||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -36,6 +36,10 @@ tf_class {
|
|||||||
name: "experimental_distribute_datasets_from_function"
|
name: "experimental_distribute_datasets_from_function"
|
||||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_distribute_values_from_function"
|
||||||
|
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_local_results"
|
name: "experimental_local_results"
|
||||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -36,6 +36,10 @@ tf_class {
|
|||||||
name: "experimental_distribute_datasets_from_function"
|
name: "experimental_distribute_datasets_from_function"
|
||||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_distribute_values_from_function"
|
||||||
|
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_local_results"
|
name: "experimental_local_results"
|
||||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -36,6 +36,10 @@ tf_class {
|
|||||||
name: "experimental_distribute_datasets_from_function"
|
name: "experimental_distribute_datasets_from_function"
|
||||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
}
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_distribute_values_from_function"
|
||||||
|
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "experimental_local_results"
|
name: "experimental_local_results"
|
||||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
@ -0,0 +1,17 @@
|
|||||||
|
path: "tensorflow.distribute.experimental.ValueContext"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ValueContext\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "num_replicas_in_sync"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "replica_id_in_sync_group"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'replica_id_in_sync_group\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'0\', \'1\'], "
|
||||||
|
}
|
||||||
|
}
|
@ -24,4 +24,8 @@ tf_module {
|
|||||||
name: "TPUStrategy"
|
name: "TPUStrategy"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
|
member {
|
||||||
|
name: "ValueContext"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user