Add experimental_distribute_values_from_function method to distribution strategy.

PiperOrigin-RevId: 298445422
Change-Id: I907cf7808e5bedaf45adccf3b6355ccf219e4116
Ken Franko 2020-03-02 14:20:14 -08:00 committed by TensorFlower Gardener
parent 5c1f1be66d
commit 3f35d6d8b0
17 changed files with 447 additions and 6 deletions
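In brief, the new API lets a `value_fn` produce one value per replica without going through a dataset. A minimal sketch of the call pattern this commit adds (assumes a TF build that includes it; the strategy choice is illustrative):
```
import tensorflow as tf

# Sketch only: requires a TF build that includes this commit.
strategy = tf.distribute.MirroredStrategy()

def value_fn(ctx):
  # ctx is a tf.distribute.experimental.ValueContext.
  return tf.constant(float(ctx.replica_id_in_sync_group))

distributed_values = (
    strategy.experimental_distribute_values_from_function(value_fn))
print(strategy.experimental_local_results(distributed_values))
# e.g. (<tf.Tensor: shape=(), dtype=float32, numpy=0.0>,) on a single device
```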

View File

@@ -415,6 +415,70 @@ class InputContext(object):
self.input_pipeline_id, self.num_input_pipelines)
@tf_export("distribute.experimental.ValueContext", v1=[])
class ValueContext(object):
"""A class wrapping information needed by a distribute function.
This is a context class that is passed to the `value_fn` in
`strategy.experimental_distribute_values_from_function` and contains
information about the compute replicas. The `num_replicas_in_sync` and
`replica_id` can be used to customize the value on each replica.
Example usage:
1. Directly constructed.
>>> def value_fn(context):
... return context.replica_id_in_sync_group/context.num_replicas_in_sync
>>> context = tf.distribute.experimental.ValueContext(
... replica_id_in_sync_group=2, num_replicas_in_sync=4)
>>> per_replica_value = value_fn(context)
>>> per_replica_value
0.5
2. Passed in by `experimental_distribute_values_from_function`.
>>> strategy = tf.distribute.MirroredStrategy()
>>> def value_fn(value_context):
... return value_context.num_replicas_in_sync
>>> distributed_values = (
... strategy.experimental_distribute_values_from_function(
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(1,)
"""
def __init__(self,
replica_id_in_sync_group=0,
num_replicas_in_sync=1):
"""Initializes an ValueContext object.
Args:
replica_id_in_sync_group: the current replica_id, should be an int in
[0,`num_replicas_in_sync`).
num_replicas_in_sync: the number of replicas that are in sync.
"""
self._replica_id_in_sync_group = replica_id_in_sync_group
self._num_replicas_in_sync = num_replicas_in_sync
@property
def num_replicas_in_sync(self):
"""Returns the number of compute replicas in sync."""
return self._num_replicas_in_sync
@property
def replica_id_in_sync_group(self):
"""Returns the replica ID."""
return self._replica_id_in_sync_group
def __str__(self):
return (("tf.distribute.ValueContext(replica id {}, "
" total replicas in sync: ""{})")
.format(self.replica_id_in_sync_group, self.num_replicas_in_sync))
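For reference, a hypothetical interactive check of the `__str__` format defined above (not part of the diff):
```
import tensorflow as tf

# Hypothetical REPL check; requires a build that exports ValueContext.
ctx = tf.distribute.experimental.ValueContext(
    replica_id_in_sync_group=2, num_replicas_in_sync=4)
print(ctx)
# tf.distribute.ValueContext(replica id 2,  total replicas in sync: 4)
```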
@tf_export("distribute.RunOptions") @tf_export("distribute.RunOptions")
class RunOptions( class RunOptions(
collections.namedtuple("RunOptions", [ collections.namedtuple("RunOptions", [
@@ -1212,6 +1276,81 @@ class Strategy(StrategyBase):
"""
return self._extended._experimental_replicate_to_logical_devices(tensor)  # pylint: disable=protected-access
def experimental_distribute_values_from_function(self, value_fn):
"""Generates `tf.distribute.DistributedValues` from `value_fn`.
This function generates `tf.distribute.DistributedValues` to pass
into `experimental_run_v2`, `reduce`, or other methods that take
distributed values, when you are not using datasets.
Args:
value_fn: The function to run to generate values. It is called for
each replica with `tf.distribute.ValueContext` as the sole argument. It
must return a Tensor or a type that can be converted to a Tensor.
Returns:
A `tf.distribute.DistributedValues` containing a value for each replica.
Example usage:
1. Return constant value per replica:
>>> strategy = tf.distribute.MirroredStrategy()
>>> def value_fn(ctx):
... return tf.constant(1.)
>>> distributed_values = (
... strategy.experimental_distribute_values_from_function(
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,)
2. Distribute values in array based on replica_id:
>>> strategy = tf.distribute.MirroredStrategy()
>>> array_value = np.array([3., 2., 1.])
>>> def value_fn(ctx):
... return array_value[ctx.replica_id_in_sync_group]
>>> distributed_values = (
... strategy.experimental_distribute_values_from_function(
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(3.0,)
3. Specify values using num_replicas_in_sync:
>>> strategy = tf.distribute.MirroredStrategy()
>>> def value_fn(ctx):
... return ctx.num_replicas_in_sync
>>> distributed_values = (
... strategy.experimental_distribute_values_from_function(
... value_fn))
>>> local_result = strategy.experimental_local_results(distributed_values)
>>> local_result
(1,)
4. Place values on devices and distribute:
```
strategy = tf.distribute.TPUStrategy()
worker_devices = strategy.extended.worker_devices
multiple_values = []
for i in range(strategy.num_replicas_in_sync):
with tf.device(worker_devices[i]):
multiple_values.append(tf.constant(1.0))
def value_fn(ctx):
return multiple_values[ctx.replica_id_in_sync_group]
distributed_values = strategy.experimental_distribute_values_from_function(
value_fn)
```
"""
return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access
value_fn)
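Since the docstring above points at `experimental_run_v2` but the numbered examples stop short of it, here is a hedged end-to-end sketch that mirrors the `testMakeDistributedValueAndRun` test added later in this commit (names are illustrative):
```
import tensorflow as tf

# Sketch: build per-replica inputs without a dataset, then run on them.
strategy = tf.distribute.MirroredStrategy()

def value_fn(ctx):
  return tf.constant(float(ctx.replica_id_in_sync_group))

distributed_values = (
    strategy.experimental_distribute_values_from_function(value_fn))

def computation(x):
  return tf.math.square(x)

outputs = strategy.experimental_local_results(
    strategy.experimental_run_v2(computation, args=(distributed_values,)))
# One squared replica id per local replica, e.g. (0.0,) on a single device.
```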
# TF v1.x version has additional deprecated APIs
@tf_export(v1=["distribute.Strategy"])
@@ -1717,6 +1856,9 @@ class StrategyExtendedV2(object):
def _experimental_distribute_datasets_from_function(self, dataset_fn):
raise NotImplementedError("must be implemented in descendants")
def _experimental_distribute_values_from_function(self, value_fn):
raise NotImplementedError("must be implemented in descendants")
def _reduce(self, reduce_op, value):
# Default implementation until we have an implementation for each strategy.
return self._local_results(
@@ -2436,6 +2578,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
def _experimental_distribute_datasets_from_function(self, dataset_fn):
return dataset_fn(InputContext())
def _experimental_distribute_values_from_function(self, value_fn):
return value_fn(ValueContext())
def _make_dataset_iterator(self, dataset):
return _DefaultDistributionExtended.DefaultInputIterator(dataset)

View File

@@ -661,6 +661,14 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
input_contexts,
self._container_strategy())
def _experimental_distribute_values_from_function(self, value_fn):
per_replica_values = []
for replica_id in range(self._num_replicas_in_sync):
per_replica_values.append(value_fn(
distribute_lib.ValueContext(replica_id,
self._num_replicas_in_sync)))
return values.regroup(per_replica_values, always_wrap=True)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
initial_loop_values=None):

View File

@@ -307,6 +307,12 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
[distribute_lib.InputContext()],
self._container_strategy())
def _experimental_distribute_values_from_function(self, value_fn):
# TODO(b/137795644): This should return a PerReplica value but other
# methods like experimental_run_v2 in OneDeviceStrategy need to be modified
# to do the same.
return value_fn(distribute_lib.ValueContext())
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
initial_loop_values=None):

View File

@@ -370,6 +370,12 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
[input_context],
self._container_strategy())
def _experimental_distribute_values_from_function(self, value_fn):
# TODO(b/137795644): Implement this method for ParameterServerStrategy if
# needed.
raise NotImplementedError("_experimental_distribute_values_from_function "
"not yet implemented in ParameterServerStrategy.")
def _broadcast_to(self, tensor, destinations):
# This is both a fast path for Python constants, and a way to delay
# converting Python values to a tensor until we know what type it

View File

@@ -255,12 +255,19 @@ strategies_minus_tpu = [
mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
]
strategies_minus_default_and_tpu = [
one_device_strategy, one_device_strategy_gpu,
mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
]
tpu_strategies = [
tpu_strategy, # steps_per_run=2
tpu_strategy_one_step,
cloud_tpu_strategy,
]
all_strategies_minus_default = strategies_minus_default_and_tpu + tpu_strategies
all_strategies = strategies_minus_tpu + tpu_strategies
multidevice_strategies = [

View File

@@ -431,6 +431,14 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
input_contexts,
self._container_strategy())
def _experimental_distribute_values_from_function(self, value_fn):
per_replica_values = []
for replica_id in range(self._num_replicas_in_sync):
per_replica_values.append(
value_fn(distribute_lib.ValueContext(replica_id,
self._num_replicas_in_sync)))
return values.regroup(per_replica_values, always_wrap=True)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
# TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
# a mechanism to infer the outputs of `fn`. Pending b/110550782.

View File

@@ -1071,8 +1071,17 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
_tensor_conversion_sync_on_read)
def regroup(values, wrap_class=PerReplica, always_wrap=False):
"""Makes a nest per-replica into a nest of PerReplica/Mirrored values.
Args:
values: Values to regroup.
wrap_class: Class that the `values` will be wrapped in.
always_wrap: Always wrap the `values` in `wrap_class`, even if all the
per-replica values are identical; a `DistributedVariable` is still
returned unwrapped.
Returns:
Wrapped `values`.
"""
v0 = values[0]
if isinstance(v0, list):
@@ -1124,14 +1133,16 @@ def regroup(values, wrap_class=PerReplica):
# devices), we want to return it. We check DistributedVariable
# specifically since it can look like it has a
# _distributed_container member since its members do.
if same_id and isinstance(v0, DistributedVariable):
return v0
# * If v0 is a member of a distributed variable, in which case
# hasattr(v0, "_distributed_container") is true, we want to
# return the DistributedVariable that contains it using the
# _distributed_container logic below. This case can trigger
# same_id when there is only one device.
# * In any other situation, same_id means we return v0 unless `always_wrap` is
# true.
if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
return v0
# Detect the case where each device has a parallel component of the
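A hedged sketch of what the new `always_wrap` flag changes, using the internal import paths assumed from elsewhere in this diff:
```
# Sketch only: values.regroup is an internal API.
from tensorflow.python.distribute import values as values_lib
from tensorflow.python.framework import constant_op

t = constant_op.constant(1.0)

# same_id holds (every replica sees the same Python object) and t is not a
# DistributedVariable, so the default path returns the bare tensor...
assert values_lib.regroup((t, t)) is t

# ...whereas always_wrap=True forces a PerReplica wrapper even in that case.
wrapped = values_lib.regroup((t, t), always_wrap=True)
assert isinstance(wrapped, values_lib.PerReplica)
```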

View File

@@ -23,7 +23,10 @@ import itertools
import os
from absl.testing import parameterized
import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import tf2
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context
@@ -39,12 +42,14 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.saved_model.model_utils import mode_keys
@@ -54,7 +59,7 @@ from tensorflow.python.training.tracking import util as trackable_utils
from tensorflow.python.util import nest
class DistributedValuesTest(test.TestCase, parameterized.TestCase):
def testGetEager(self):
one = constant_op.constant(1)
@@ -73,6 +78,202 @@ class DistributedValuesTest(test.TestCase):
with distribute_lib.ReplicaContext(None, 1):
self.assertEqual(two, v._get())
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueFromTensor(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
single_value = constant_op.constant(1)
def value_fn(ctx):
del ctx
return single_value
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
self.assertAllEqual(
distribution.experimental_local_results(distributed_values),
constant_op.constant(1., shape=(distribution.num_replicas_in_sync)))
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
array_value = np.array([1., 2., 3.])
def value_fn(ctx):
del ctx
return array_value
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
local_results = distribution.experimental_local_results(distributed_values)
self.assertLen(local_results, distribution.num_replicas_in_sync)
for result in local_results:
self.assertAllEqual(result, [1., 2., 3.])
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueTupleConstant(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
tuple_value = (1., 2., 3.)
def value_fn(ctx):
del ctx
return tuple_value
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
local_results = distribution.experimental_local_results(distributed_values)
for result in local_results:
self.assertAllEqual(result, (1., 2., 3.))
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
tuple_value = (1., 2., 3.)
def value_fn(ctx):
per_replica = []
for val in tuple_value:
per_replica.append(val * ctx.replica_id_in_sync_group)
return per_replica
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(
values.select_replica(i, distributed_values),
(1. * i, 2. * i, 3. * i))
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueSparseTensor(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
def value_fn(ctx):
del ctx
return sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
local_results = distribution.experimental_local_results(distributed_values)
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(
sparse_ops.sparse_tensor_to_dense(local_results[i]),
[[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueExtractFromArray(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
multiple_values = range(distribution.num_replicas_in_sync)
def value_fn(ctx):
return multiple_values[ctx.replica_id_in_sync_group]
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
local_results = distribution.experimental_local_results(distributed_values)
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(local_results[i], i)
@combinations.generate(
combinations.combine(
distribution=strategy_combinations.all_strategies_minus_default,
mode=["eager"]
))
def testMakeDistributedValueAndRun(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
@def_function.function
def run():
multiple_values = range(distribution.num_replicas_in_sync)
def value_fn(ctx):
return multiple_values[ctx.replica_id_in_sync_group]
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
def computation(x):
return math_ops.square(x)
outputs = distribution.experimental_local_results(
distribution.experimental_run_v2(computation,
args=(distributed_values,)))
return outputs
local_results = run()
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(local_results[i], i**2)
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["eager"]
))
def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
multiple_values = []
for i in range(distribution.num_replicas_in_sync):
multiple_values.append(constant_op.constant(1.0))
def value_fn(ctx):
return multiple_values[ctx.replica_id_in_sync_group]
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(distributed_values._values[i].device,
"/job:localhost/replica:0/task:0/device:CPU:0")
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
strategy_combinations.tpu_strategy,
strategy_combinations.central_storage_strategy_with_two_gpus,
],
mode=["eager"]
))
def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
if not tf2.enabled():
self.skipTest("Only V2 is supported.")
worker_devices = distribution.extended.worker_devices
multiple_values = []
for i in range(distribution.num_replicas_in_sync):
with ops.device(worker_devices[i]):
multiple_values.append(array_ops.identity(1.0))
def value_fn(ctx):
return multiple_values[ctx.replica_id_in_sync_group]
distributed_values = (
distribution.experimental_distribute_values_from_function(value_fn))
for i in range(distribution.num_replicas_in_sync):
self.assertAllEqual(distributed_values._values[i].device,
worker_devices[i])
class DistributedDelegateTest(test.TestCase):

View File

@@ -36,6 +36,10 @@ tf_class {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_distribute_values_from_function"
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -36,6 +36,10 @@ tf_class {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_distribute_values_from_function"
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -35,6 +35,10 @@ tf_class {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_distribute_values_from_function"
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -36,6 +36,10 @@ tf_class {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_distribute_values_from_function"
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -36,6 +36,10 @@ tf_class {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_distribute_values_from_function"
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -36,6 +36,10 @@ tf_class {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_distribute_values_from_function"
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -36,6 +36,10 @@ tf_class {
name: "experimental_distribute_datasets_from_function"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_distribute_values_from_function"
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_local_results"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"

View File

@@ -0,0 +1,17 @@
path: "tensorflow.distribute.experimental.ValueContext"
tf_class {
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ValueContext\'>"
is_instance: "<type \'object\'>"
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member {
name: "replica_id_in_sync_group"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'replica_id_in_sync_group\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'0\', \'1\'], "
}
}

View File

@@ -24,4 +24,8 @@ tf_module {
name: "TPUStrategy"
mtype: "<type \'type\'>"
}
member {
name: "ValueContext"
mtype: "<type \'type\'>"
}
}