From 3f35d6d8b01aba989f36c5f7b7bfdce33dd97f2b Mon Sep 17 00:00:00 2001
From: Ken Franko
Date: Mon, 2 Mar 2020 14:20:14 -0800
Subject: [PATCH] Add experimental_make_distributed_values_from_function method
 to distribution strategy.

PiperOrigin-RevId: 298445422
Change-Id: I907cf7808e5bedaf45adccf3b6355ccf219e4116
---
 .../python/distribute/distribute_lib.py       | 145 +++++++++++++
 .../python/distribute/mirrored_strategy.py    |   8 +
 .../python/distribute/one_device_strategy.py  |   6 +
 .../distribute/parameter_server_strategy.py   |   6 +
 .../distribute/strategy_combinations.py       |   7 +
 tensorflow/python/distribute/tpu_strategy.py  |   8 +
 tensorflow/python/distribute/values.py        |  21 +-
 tensorflow/python/distribute/values_test.py   | 203 +++++++++++++++++-
 ...orflow.distribute.-mirrored-strategy.pbtxt |   4 +
 ...flow.distribute.-one-device-strategy.pbtxt |   4 +
 .../v2/tensorflow.distribute.-strategy.pbtxt  |   4 +
 ...perimental.-central-storage-strategy.pbtxt |   4 +
 ...ntal.-multi-worker-mirrored-strategy.pbtxt |   4 +
 ...erimental.-parameter-server-strategy.pbtxt |   4 +
 ...tribute.experimental.-t-p-u-strategy.pbtxt |   4 +
 ...stribute.experimental.-value-context.pbtxt |  17 ++
 .../tensorflow.distribute.experimental.pbtxt  |   4 +
 17 files changed, 447 insertions(+), 6 deletions(-)
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-value-context.pbtxt

diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 1d1e44f97c9..73ea4abf081 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -415,6 +415,70 @@ class InputContext(object):
         self.input_pipeline_id, self.num_input_pipelines)
 
 
+@tf_export("distribute.experimental.ValueContext", v1=[])
+class ValueContext(object):
+  """A class wrapping information needed by a distribute function.
+
+  This is a context class that is passed to the `value_fn` in
+  `strategy.experimental_distribute_values_from_function` and contains
+  information about the compute replicas. The `num_replicas_in_sync` and
+  `replica_id_in_sync_group` can be used to customize the value on each
+  replica.
+
+  Example usage:
+
+  1. Directly constructed.
+
+  >>> def value_fn(context):
+  ...   return context.replica_id_in_sync_group/context.num_replicas_in_sync
+  >>> context = tf.distribute.experimental.ValueContext(
+  ...   replica_id_in_sync_group=2, num_replicas_in_sync=4)
+  >>> per_replica_value = value_fn(context)
+  >>> per_replica_value
+  0.5
+
+  2. Passed in by `experimental_distribute_values_from_function`.
+
+  >>> strategy = tf.distribute.MirroredStrategy()
+  >>> def value_fn(value_context):
+  ...   return value_context.num_replicas_in_sync
+  >>> distributed_values = (
+  ...      strategy.experimental_distribute_values_from_function(
+  ...        value_fn))
+  >>> local_result = strategy.experimental_local_results(distributed_values)
+  >>> local_result
+  (1,)
+
+  """
+
+  def __init__(self,
+               replica_id_in_sync_group=0,
+               num_replicas_in_sync=1):
+    """Initializes a ValueContext object.
+
+    Args:
+      replica_id_in_sync_group: the current replica_id, should be an int in
+        [0, `num_replicas_in_sync`).
+      num_replicas_in_sync: the number of replicas that are in sync.
+ """ + self._replica_id_in_sync_group = replica_id_in_sync_group + self._num_replicas_in_sync = num_replicas_in_sync + + @property + def num_replicas_in_sync(self): + """Returns the number of compute replicas in sync.""" + return self._num_replicas_in_sync + + @property + def replica_id_in_sync_group(self): + """Returns the replica ID.""" + return self._replica_id_in_sync_group + + def __str__(self): + return (("tf.distribute.ValueContext(replica id {}, " + " total replicas in sync: ""{})") + .format(self.replica_id_in_sync_group, self.num_replicas_in_sync)) + + @tf_export("distribute.RunOptions") class RunOptions( collections.namedtuple("RunOptions", [ @@ -1212,6 +1276,81 @@ class Strategy(StrategyBase): """ return self._extended._experimental_replicate_to_logical_devices(tensor) # pylint: disable=protected-access + def experimental_distribute_values_from_function(self, value_fn): + """Generates `tf.distribute.DistributedValues` from `value_fn`. + + This function is to generate `tf.distribute.DistributedValues` to pass + into `experimental_run_v2`, `reduce`, or other methods that take + distributed values when not using datasets. + + Args: + value_fn: The function to run to generate values. It is called for + each replica with `tf.distribute.ValueContext` as the sole argument. It + must return a Tensor or a type that can be converted to a Tensor. + Returns: + A `tf.distribute.DistributedValues` containing a value for each replica. + + Example usage: + + 1. Return constant value per replica: + + >>> strategy = tf.distribute.MirroredStrategy() + >>> def value_fn(ctx): + ... return tf.constant(1.) + >>> distributed_values = ( + ... strategy.experimental_distribute_values_from_function( + ... value_fn)) + >>> local_result = strategy.experimental_local_results(distributed_values) + >>> local_result + (,) + + 2. Distribute values in array based on replica_id: + + >>> strategy = tf.distribute.MirroredStrategy() + >>> array_value = np.array([3., 2., 1.]) + >>> def value_fn(ctx): + ... return array_value[ctx.replica_id_in_sync_group] + >>> distributed_values = ( + ... strategy.experimental_distribute_values_from_function( + ... value_fn)) + >>> local_result = strategy.experimental_local_results(distributed_values) + >>> local_result + (3.0,) + + 3. Specify values using num_replicas_in_sync: + + >>> strategy = tf.distribute.MirroredStrategy() + >>> def value_fn(ctx): + ... return ctx.num_replicas_in_sync + >>> distributed_values = ( + ... strategy.experimental_distribute_values_from_function( + ... value_fn)) + >>> local_result = strategy.experimental_local_results(distributed_values) + >>> local_result + (1,) + + 4. Place values on devices and distribute: + + ``` + strategy = tf.distribute.TPUStrategy() + worker_devices = strategy.extended.worker_devices + multiple_values = [] + for i in range(strategy.num_replicas_in_sync): + with tf.device(worker_devices[i]): + multiple_values.append(tf.constant(1.0)) + + def value_fn(ctx): + return multiple_values[ctx.replica_id] + + distributed_values = strategy. 
+        strategy.experimental_distribute_values_from_function(
+            value_fn))
+    ```
+
+    """
+    return self._extended._experimental_distribute_values_from_function(  # pylint: disable=protected-access
+        value_fn)
+
 
 # TF v1.x version has additional deprecated APIs
 @tf_export(v1=["distribute.Strategy"])
@@ -1717,6 +1856,9 @@ class StrategyExtendedV2(object):
   def _experimental_distribute_datasets_from_function(self, dataset_fn):
     raise NotImplementedError("must be implemented in descendants")
 
+  def _experimental_distribute_values_from_function(self, value_fn):
+    raise NotImplementedError("must be implemented in descendants")
+
   def _reduce(self, reduce_op, value):
     # Default implementation until we have an implementation for each strategy.
     return self._local_results(
@@ -2436,6 +2578,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
   def _experimental_distribute_datasets_from_function(self, dataset_fn):
     return dataset_fn(InputContext())
 
+  def _experimental_distribute_values_from_function(self, value_fn):
+    return value_fn(ValueContext())
+
   def _make_dataset_iterator(self, dataset):
     return _DefaultDistributionExtended.DefaultInputIterator(dataset)
 
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index baa6e1ac76e..6da586c251b 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -661,6 +661,14 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
         input_contexts,
         self._container_strategy())
 
+  def _experimental_distribute_values_from_function(self, value_fn):
+    per_replica_values = []
+    for replica_id in range(self._num_replicas_in_sync):
+      per_replica_values.append(value_fn(
+          distribute_lib.ValueContext(replica_id,
+                                      self._num_replicas_in_sync)))
+    return values.regroup(per_replica_values, always_wrap=True)
+
   # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
   def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                           initial_loop_values=None):
diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py
index 5a6973a699b..6fa7cb27539 100644
--- a/tensorflow/python/distribute/one_device_strategy.py
+++ b/tensorflow/python/distribute/one_device_strategy.py
@@ -307,6 +307,12 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
         [distribute_lib.InputContext()],
         self._container_strategy())
 
+  def _experimental_distribute_values_from_function(self, value_fn):
+    # TODO(b/137795644): This should return a PerReplica value but other
+    # methods like experimental_run_v2 in OneDeviceStrategy need to be modified
+    # to do the same.
+    return value_fn(distribute_lib.ValueContext())
+
   # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
   def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
                                           initial_loop_values=None):
diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py
index 2c2b39b4335..7099e4a6390 100644
--- a/tensorflow/python/distribute/parameter_server_strategy.py
+++ b/tensorflow/python/distribute/parameter_server_strategy.py
@@ -370,6 +370,12 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
         [input_context],
         self._container_strategy())
 
+  def _experimental_distribute_values_from_function(self, value_fn):
+    # TODO(b/137795644): Implement this method for ParameterServerStrategy if
+    # needed.
+    raise NotImplementedError("_experimental_distribute_values_from_function "
+                              "not yet implemented in ParameterServerStrategy.")
+
   def _broadcast_to(self, tensor, destinations):
     # This is both a fast path for Python constants, and a way to delay
     # converting Python values to a tensor until we know what type it
diff --git a/tensorflow/python/distribute/strategy_combinations.py b/tensorflow/python/distribute/strategy_combinations.py
index 12a2b48d236..72e4324d0b6 100644
--- a/tensorflow/python/distribute/strategy_combinations.py
+++ b/tensorflow/python/distribute/strategy_combinations.py
@@ -255,12 +255,19 @@ strategies_minus_tpu = [
     mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
 ]
 
+strategies_minus_default_and_tpu = [
+    one_device_strategy, one_device_strategy_gpu,
+    mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
+]
+
 tpu_strategies = [
     tpu_strategy,  # steps_per_run=2
     tpu_strategy_one_step,
     cloud_tpu_strategy,
 ]
 
+all_strategies_minus_default = strategies_minus_default_and_tpu + tpu_strategies
+
 all_strategies = strategies_minus_tpu + tpu_strategies
 
 multidevice_strategies = [
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 7176a2e2dc9..1d1c658d5e0 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -431,6 +431,14 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         input_contexts,
         self._container_strategy())
 
+  def _experimental_distribute_values_from_function(self, value_fn):
+    per_replica_values = []
+    for replica_id in range(self._num_replicas_in_sync):
+      per_replica_values.append(
+          value_fn(distribute_lib.ValueContext(replica_id,
+                                               self._num_replicas_in_sync)))
+    return values.regroup(per_replica_values, always_wrap=True)
+
   # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
   # TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
   # a mechanism to infer the outputs of `fn`. Pending b/110550782.
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index d0350657638..386a6b9790c 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -1071,8 +1071,17 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
                                         _tensor_conversion_sync_on_read)
 
 
-def regroup(values, wrap_class=PerReplica):
-  """Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
+def regroup(values, wrap_class=PerReplica, always_wrap=False):
+  """Makes a nest per-replica into a nest of PerReplica/Mirrored values.
+
+  Args:
+    values: Values to regroup.
+    wrap_class: Class that `values` will be wrapped in.
+    always_wrap: If True, always wrap `values` in `wrap_class`, even if all of
+      the values are the same object, except for DistributedVariable.
+  Returns:
+    Wrapped `values`.
+  """
   v0 = values[0]
 
   if isinstance(v0, list):
@@ -1124,14 +1133,16 @@ def regroup(values, wrap_class=PerReplica):
   #   devices), we want to return it. We check DistributedVariable
   #   specifically since it can look like it has a
   #   _distributed_container member since its members do.
+  if same_id and isinstance(v0, DistributedVariable):
+    return v0
   # * If v0 is a member of a distributed variable, in which case
   #   hasattr(v0, "_distributed_container") is true, we want to
   #   return the DistributedVariable that contains it using the
   #   _distributed_container logic below. This case can trigger
   #   same_id when there is only one device.
-  # * In any other situation, same_id means we return v0.
-  if same_id and (isinstance(v0, DistributedVariable) or
-                  not hasattr(v0, "_distributed_container")):
+  # * In any other situation, same_id means we return v0 unless `always_wrap`
+  #   is true.
+  if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
     return v0
 
   # Detect the case where each device has a parallel component of the
diff --git a/tensorflow/python/distribute/values_test.py b/tensorflow/python/distribute/values_test.py
index 6ed17696843..1035df489c2 100644
--- a/tensorflow/python/distribute/values_test.py
+++ b/tensorflow/python/distribute/values_test.py
@@ -23,7 +23,10 @@ import itertools
 import os
 
 from absl.testing import parameterized
+import numpy as np
+
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import tf2
 from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import distribution_strategy_context
@@ -39,12 +42,14 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import indexed_slices
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables as variables_lib
 from tensorflow.python.saved_model.model_utils import mode_keys
@@ -54,7 +59,7 @@ from tensorflow.python.training.tracking import util as trackable_utils
 from tensorflow.python.util import nest
 
 
-class DistributedValuesTest(test.TestCase):
+class DistributedValuesTest(test.TestCase, parameterized.TestCase):
 
   def testGetEager(self):
     one = constant_op.constant(1)
@@ -73,6 +78,202 @@ class DistributedValuesTest(test.TestCase):
     with distribute_lib.ReplicaContext(None, 1):
       self.assertEqual(two, v._get())
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies_minus_default,
+          mode=["eager"]
+      ))
+  def testMakeDistributedValueFromTensor(self, distribution):
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+    single_value = constant_op.constant(1)
+    def value_fn(ctx):
+      del ctx
+      return single_value
+
+    distributed_values = (
+        distribution.experimental_distribute_values_from_function(value_fn))
+    self.assertAllEqual(
+        distribution.experimental_local_results(distributed_values),
+        constant_op.constant(1., shape=(distribution.num_replicas_in_sync)))
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies_minus_default,
+          mode=["eager"]
+      ))
+  def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+    array_value = np.array([1., 2., 3.])
+    def value_fn(ctx):
+      del ctx
+      return array_value
+
+    distributed_values = (
+        distribution.experimental_distribute_values_from_function(value_fn))
+    local_results = distribution.experimental_local_results(distributed_values)
+    self.assertLen(local_results, distribution.num_replicas_in_sync)
+    for result in local_results:
+      self.assertAllEqual(result, [1., 2., 3.])
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies_minus_default,
+          mode=["eager"]
+      ))
+  def testMakeDistributedValueTupleConstant(self, distribution):
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+    tuple_value = (1., 2., 3.)
+    def value_fn(ctx):
+      del ctx
+      return tuple_value
+    distributed_values = (
+        distribution.experimental_distribute_values_from_function(value_fn))
+    local_results = distribution.experimental_local_results(distributed_values)
+    for result in local_results:
+      self.assertAllEqual(result, (1., 2., 3.))
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies_minus_default,
+          mode=["eager"]
+      ))
+  def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+    tuple_value = (1., 2., 3.)
+    def value_fn(ctx):
+      per_replica = []
+      for val in tuple_value:
+        per_replica.append(val * ctx.replica_id_in_sync_group)
+      return per_replica
+    distributed_values = (
+        distribution.experimental_distribute_values_from_function(value_fn))
+    for i in range(distribution.num_replicas_in_sync):
+      self.assertAllEqual(
+          values.select_replica(i, distributed_values),
+          (1. * i, 2. * i, 3. * i))
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies_minus_default,
+          mode=["eager"]
+      ))
+  def testMakeDistributedValueSparseTensor(self, distribution):
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+    def value_fn(ctx):
+      del ctx
+      return sparse_tensor.SparseTensor(
+          indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
+
+    distributed_values = (
+        distribution.experimental_distribute_values_from_function(value_fn))
+    local_results = distribution.experimental_local_results(distributed_values)
+    for i in range(distribution.num_replicas_in_sync):
+      self.assertAllEqual(
+          sparse_ops.sparse_tensor_to_dense(local_results[i]),
+          [[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies_minus_default,
+          mode=["eager"]
+      ))
+  def testMakeDistributedValueExtractFromArray(self, distribution):
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+    multiple_values = range(distribution.num_replicas_in_sync)
+    def value_fn(ctx):
+      return multiple_values[ctx.replica_id_in_sync_group]
+    distributed_values = (
+        distribution.experimental_distribute_values_from_function(value_fn))
+    local_results = distribution.experimental_local_results(distributed_values)
+    for i in range(distribution.num_replicas_in_sync):
+      self.assertAllEqual(local_results[i], i)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=strategy_combinations.all_strategies_minus_default,
+          mode=["eager"]
+      ))
+  def testMakeDistributedValueAndRun(self, distribution):
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+
+    @def_function.function
+    def run():
+      multiple_values = range(distribution.num_replicas_in_sync)
+      def value_fn(ctx):
+        return multiple_values[ctx.replica_id_in_sync_group]
+      distributed_values = (
+          distribution.experimental_distribute_values_from_function(value_fn))
+
+      def computation(x):
+        return math_ops.square(x)
+
+      outputs = distribution.experimental_local_results(
+          distribution.experimental_run_v2(computation,
+                                           args=(distributed_values,)))
+      return outputs
+
+    local_results = run()
+
+    for i in range(distribution.num_replicas_in_sync):
+      self.assertAllEqual(local_results[i], i**2)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.central_storage_strategy_with_two_gpus,
+          ],
+          mode=["eager"]
+      ))
+  def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+    multiple_values = []
+    for i in range(distribution.num_replicas_in_sync):
+      multiple_values.append(constant_op.constant(1.0))
+
+    def value_fn(ctx):
+      return multiple_values[ctx.replica_id_in_sync_group]
+    distributed_values = (
+        distribution.experimental_distribute_values_from_function(value_fn))
+    for i in range(distribution.num_replicas_in_sync):
+      self.assertAllEqual(distributed_values._values[i].device,
+                          "/job:localhost/replica:0/task:0/device:CPU:0")
+
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.central_storage_strategy_with_two_gpus,
+          ],
+          mode=["eager"]
+      ))
+  def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
+    if not tf2.enabled():
+      self.skipTest("Only V2 is supported.")
+    worker_devices = distribution.extended.worker_devices
+    multiple_values = []
+    for i in range(distribution.num_replicas_in_sync):
+      with ops.device(worker_devices[i]):
+        multiple_values.append(array_ops.identity(1.0))
+
+    def value_fn(ctx):
+      return multiple_values[ctx.replica_id_in_sync_group]
+    distributed_values = (
+        distribution.experimental_distribute_values_from_function(value_fn))
+    for i in range(distribution.num_replicas_in_sync):
+      self.assertAllEqual(distributed_values._values[i].device,
+                          worker_devices[i])
+
 
 class DistributedDelegateTest(test.TestCase):
 
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt
index 851d2eb56cf..34e6be4f6f1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-mirrored-strategy.pbtxt
@@ -36,6 +36,10 @@ tf_class {
     name: "experimental_distribute_datasets_from_function"
     argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "experimental_distribute_values_from_function"
+    argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "experimental_local_results"
     argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt
index e9033f33e83..35718ae1773 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-one-device-strategy.pbtxt
@@ -36,6 +36,10 @@ tf_class {
     name: "experimental_distribute_datasets_from_function"
     argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "experimental_distribute_values_from_function"
+    argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "experimental_local_results"
     argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
"args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt index 3d8f48a1abc..1cdeab10f43 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt @@ -35,6 +35,10 @@ tf_class { name: "experimental_distribute_datasets_from_function" argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_distribute_values_from_function" + argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt index c8242fa7b01..341f138dbb7 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-central-storage-strategy.pbtxt @@ -36,6 +36,10 @@ tf_class { name: "experimental_distribute_datasets_from_function" argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_distribute_values_from_function" + argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt index c4019a1c72b..37e437d226f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-multi-worker-mirrored-strategy.pbtxt @@ -36,6 +36,10 @@ tf_class { name: "experimental_distribute_datasets_from_function" argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_distribute_values_from_function" + argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt index 1c66eddfaa5..36cbca468eb 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-parameter-server-strategy.pbtxt @@ -36,6 +36,10 @@ tf_class { name: "experimental_distribute_datasets_from_function" argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_distribute_values_from_function" + argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, 
defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt index ffa50a638c2..422f65e8eb3 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-t-p-u-strategy.pbtxt @@ -36,6 +36,10 @@ tf_class { name: "experimental_distribute_datasets_from_function" argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "experimental_distribute_values_from_function" + argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "experimental_local_results" argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-value-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-value-context.pbtxt new file mode 100644 index 00000000000..d0c52a27a58 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.-value-context.pbtxt @@ -0,0 +1,17 @@ +path: "tensorflow.distribute.experimental.ValueContext" +tf_class { + is_instance: "" + is_instance: "" + member { + name: "num_replicas_in_sync" + mtype: "" + } + member { + name: "replica_id_in_sync_group" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'replica_id_in_sync_group\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'0\', \'1\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt index 9247db37925..06151eee4b4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.experimental.pbtxt @@ -24,4 +24,8 @@ tf_module { name: "TPUStrategy" mtype: "" } + member { + name: "ValueContext" + mtype: "" + } }