Add experimental_make_distributed_values_from_function method to distribution strategy.
PiperOrigin-RevId: 298445422 Change-Id: I907cf7808e5bedaf45adccf3b6355ccf219e4116
This commit is contained in:
parent
5c1f1be66d
commit
3f35d6d8b0
@ -415,6 +415,70 @@ class InputContext(object):
|
||||
self.input_pipeline_id, self.num_input_pipelines)
|
||||
|
||||
|
||||
@tf_export("distribute.experimental.ValueContext", v1=[])
|
||||
class ValueContext(object):
|
||||
"""A class wrapping information needed by a distribute function.
|
||||
|
||||
This is a context class that is passed to the `value_fn` in
|
||||
`strategy.experimental_distribute_values_from_function` and contains
|
||||
information about the compute replicas. The `num_replicas_in_sync` and
|
||||
`replica_id` can be used to customize the value on each replica.
|
||||
|
||||
Example usage:
|
||||
|
||||
1. Directly constructed.
|
||||
|
||||
>>> def value_fn(context):
|
||||
... return context.replica_id_in_sync_group/context.num_replicas_in_sync
|
||||
>>> context = tf.distribute.experimental.ValueContext(
|
||||
... replica_id_in_sync_group=2, num_replicas_in_sync=4)
|
||||
>>> per_replica_value = value_fn(context)
|
||||
>>> per_replica_value
|
||||
0.5
|
||||
|
||||
2. Passed in by `experimental_distribute_values_from_function`.
|
||||
|
||||
>>> strategy = tf.distribute.MirroredStrategy()
|
||||
>>> def value_fn(value_context):
|
||||
... return value_context.num_replicas_in_sync
|
||||
>>> distributed_values = (
|
||||
... strategy.experimental_distribute_values_from_function(
|
||||
... value_fn))
|
||||
>>> local_result = strategy.experimental_local_results(distributed_values)
|
||||
>>> local_result
|
||||
(1,)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
replica_id_in_sync_group=0,
|
||||
num_replicas_in_sync=1):
|
||||
"""Initializes an ValueContext object.
|
||||
|
||||
Args:
|
||||
replica_id_in_sync_group: the current replica_id, should be an int in
|
||||
[0,`num_replicas_in_sync`).
|
||||
num_replicas_in_sync: the number of replicas that are in sync.
|
||||
"""
|
||||
self._replica_id_in_sync_group = replica_id_in_sync_group
|
||||
self._num_replicas_in_sync = num_replicas_in_sync
|
||||
|
||||
@property
|
||||
def num_replicas_in_sync(self):
|
||||
"""Returns the number of compute replicas in sync."""
|
||||
return self._num_replicas_in_sync
|
||||
|
||||
@property
|
||||
def replica_id_in_sync_group(self):
|
||||
"""Returns the replica ID."""
|
||||
return self._replica_id_in_sync_group
|
||||
|
||||
def __str__(self):
|
||||
return (("tf.distribute.ValueContext(replica id {}, "
|
||||
" total replicas in sync: ""{})")
|
||||
.format(self.replica_id_in_sync_group, self.num_replicas_in_sync))
|
||||
|
||||
|
||||
@tf_export("distribute.RunOptions")
|
||||
class RunOptions(
|
||||
collections.namedtuple("RunOptions", [
|
||||
@ -1212,6 +1276,81 @@ class Strategy(StrategyBase):
|
||||
"""
|
||||
return self._extended._experimental_replicate_to_logical_devices(tensor) # pylint: disable=protected-access
|
||||
|
||||
def experimental_distribute_values_from_function(self, value_fn):
|
||||
"""Generates `tf.distribute.DistributedValues` from `value_fn`.
|
||||
|
||||
This function is to generate `tf.distribute.DistributedValues` to pass
|
||||
into `experimental_run_v2`, `reduce`, or other methods that take
|
||||
distributed values when not using datasets.
|
||||
|
||||
Args:
|
||||
value_fn: The function to run to generate values. It is called for
|
||||
each replica with `tf.distribute.ValueContext` as the sole argument. It
|
||||
must return a Tensor or a type that can be converted to a Tensor.
|
||||
Returns:
|
||||
A `tf.distribute.DistributedValues` containing a value for each replica.
|
||||
|
||||
Example usage:
|
||||
|
||||
1. Return constant value per replica:
|
||||
|
||||
>>> strategy = tf.distribute.MirroredStrategy()
|
||||
>>> def value_fn(ctx):
|
||||
... return tf.constant(1.)
|
||||
>>> distributed_values = (
|
||||
... strategy.experimental_distribute_values_from_function(
|
||||
... value_fn))
|
||||
>>> local_result = strategy.experimental_local_results(distributed_values)
|
||||
>>> local_result
|
||||
(<tf.Tensor: shape=(), dtype=float32, numpy=1.0>,)
|
||||
|
||||
2. Distribute values in array based on replica_id:
|
||||
|
||||
>>> strategy = tf.distribute.MirroredStrategy()
|
||||
>>> array_value = np.array([3., 2., 1.])
|
||||
>>> def value_fn(ctx):
|
||||
... return array_value[ctx.replica_id_in_sync_group]
|
||||
>>> distributed_values = (
|
||||
... strategy.experimental_distribute_values_from_function(
|
||||
... value_fn))
|
||||
>>> local_result = strategy.experimental_local_results(distributed_values)
|
||||
>>> local_result
|
||||
(3.0,)
|
||||
|
||||
3. Specify values using num_replicas_in_sync:
|
||||
|
||||
>>> strategy = tf.distribute.MirroredStrategy()
|
||||
>>> def value_fn(ctx):
|
||||
... return ctx.num_replicas_in_sync
|
||||
>>> distributed_values = (
|
||||
... strategy.experimental_distribute_values_from_function(
|
||||
... value_fn))
|
||||
>>> local_result = strategy.experimental_local_results(distributed_values)
|
||||
>>> local_result
|
||||
(1,)
|
||||
|
||||
4. Place values on devices and distribute:
|
||||
|
||||
```
|
||||
strategy = tf.distribute.TPUStrategy()
|
||||
worker_devices = strategy.extended.worker_devices
|
||||
multiple_values = []
|
||||
for i in range(strategy.num_replicas_in_sync):
|
||||
with tf.device(worker_devices[i]):
|
||||
multiple_values.append(tf.constant(1.0))
|
||||
|
||||
def value_fn(ctx):
|
||||
return multiple_values[ctx.replica_id]
|
||||
|
||||
distributed_values = strategy.
|
||||
experimental_distribute_values_from_function(
|
||||
value_fn)
|
||||
```
|
||||
|
||||
"""
|
||||
return self._extended._experimental_distribute_values_from_function( # pylint: disable=protected-access
|
||||
value_fn)
|
||||
|
||||
|
||||
# TF v1.x version has additional deprecated APIs
|
||||
@tf_export(v1=["distribute.Strategy"])
|
||||
@ -1717,6 +1856,9 @@ class StrategyExtendedV2(object):
|
||||
def _experimental_distribute_datasets_from_function(self, dataset_fn):
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
def _experimental_distribute_values_from_function(self, value_fn):
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
def _reduce(self, reduce_op, value):
|
||||
# Default implementation until we have an implementation for each strategy.
|
||||
return self._local_results(
|
||||
@ -2436,6 +2578,9 @@ class _DefaultDistributionExtended(StrategyExtendedV1):
|
||||
def _experimental_distribute_datasets_from_function(self, dataset_fn):
|
||||
return dataset_fn(InputContext())
|
||||
|
||||
def _experimental_distribute_values_from_function(self, value_fn):
|
||||
return value_fn(ValueContext())
|
||||
|
||||
def _make_dataset_iterator(self, dataset):
|
||||
return _DefaultDistributionExtended.DefaultInputIterator(dataset)
|
||||
|
||||
|
@ -661,6 +661,14 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
input_contexts,
|
||||
self._container_strategy())
|
||||
|
||||
def _experimental_distribute_values_from_function(self, value_fn):
|
||||
per_replica_values = []
|
||||
for replica_id in range(self._num_replicas_in_sync):
|
||||
per_replica_values.append(value_fn(
|
||||
distribute_lib.ValueContext(replica_id,
|
||||
self._num_replicas_in_sync)))
|
||||
return values.regroup(per_replica_values, always_wrap=True)
|
||||
|
||||
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
|
||||
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
|
||||
initial_loop_values=None):
|
||||
|
@ -307,6 +307,12 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
|
||||
[distribute_lib.InputContext()],
|
||||
self._container_strategy())
|
||||
|
||||
def _experimental_distribute_values_from_function(self, value_fn):
|
||||
# TODO(b/137795644): This should return a PerReplica value but other
|
||||
# methods like experimental_run_v2 in OneDeviceStrategy need to be modified
|
||||
# to do the same.
|
||||
return value_fn(distribute_lib.ValueContext())
|
||||
|
||||
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
|
||||
def _experimental_run_steps_on_iterator(self, fn, iterator, iterations,
|
||||
initial_loop_values=None):
|
||||
|
@ -370,6 +370,12 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
||||
[input_context],
|
||||
self._container_strategy())
|
||||
|
||||
def _experimental_distribute_values_from_function(self, value_fn):
|
||||
# TODO(b/137795644): Implement this method for ParameterServerStrategy if
|
||||
# needed.
|
||||
raise NotImplementedError("_experimental_distribute_values_from_function "
|
||||
"not yet implemented in ParameterServerStrategy.")
|
||||
|
||||
def _broadcast_to(self, tensor, destinations):
|
||||
# This is both a fast path for Python constants, and a way to delay
|
||||
# converting Python values to a tensor until we know what type it
|
||||
|
@ -255,12 +255,19 @@ strategies_minus_tpu = [
|
||||
mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
|
||||
]
|
||||
|
||||
strategies_minus_default_and_tpu = [
|
||||
one_device_strategy, one_device_strategy_gpu,
|
||||
mirrored_strategy_with_gpu_and_cpu, mirrored_strategy_with_two_gpus
|
||||
]
|
||||
|
||||
tpu_strategies = [
|
||||
tpu_strategy, # steps_per_run=2
|
||||
tpu_strategy_one_step,
|
||||
cloud_tpu_strategy,
|
||||
]
|
||||
|
||||
all_strategies_minus_default = strategies_minus_default_and_tpu + tpu_strategies
|
||||
|
||||
all_strategies = strategies_minus_tpu + tpu_strategies
|
||||
|
||||
multidevice_strategies = [
|
||||
|
@ -431,6 +431,14 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
input_contexts,
|
||||
self._container_strategy())
|
||||
|
||||
def _experimental_distribute_values_from_function(self, value_fn):
|
||||
per_replica_values = []
|
||||
for replica_id in range(self._num_replicas_in_sync):
|
||||
per_replica_values.append(
|
||||
value_fn(distribute_lib.ValueContext(replica_id,
|
||||
self._num_replicas_in_sync)))
|
||||
return values.regroup(per_replica_values, always_wrap=True)
|
||||
|
||||
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
|
||||
# TODO(sourabhbajaj): Remove the initial_loop_values parameter when we have
|
||||
# a mechanism to infer the outputs of `fn`. Pending b/110550782.
|
||||
|
@ -1071,8 +1071,17 @@ ops.register_tensor_conversion_function(SyncOnReadVariable,
|
||||
_tensor_conversion_sync_on_read)
|
||||
|
||||
|
||||
def regroup(values, wrap_class=PerReplica):
|
||||
"""Makes a nest per-replica into a nest of PerReplica/Mirrored values."""
|
||||
def regroup(values, wrap_class=PerReplica, always_wrap=False):
|
||||
"""Makes a nest per-replica into a nest of PerReplica/Mirrored values.
|
||||
|
||||
Args:
|
||||
values: Values to regroup
|
||||
wrap_class: Class that `values` be wrapped in.
|
||||
always_wrap: Always wrap the `values` in `wrap_class` even if the values
|
||||
are the same except for DistributeVariable.
|
||||
Returns:
|
||||
Wrapped `values`.
|
||||
"""
|
||||
v0 = values[0]
|
||||
|
||||
if isinstance(v0, list):
|
||||
@ -1124,14 +1133,16 @@ def regroup(values, wrap_class=PerReplica):
|
||||
# devices), we want to return it. We check DistributedVariable
|
||||
# specifically since it can look like it has a
|
||||
# _distributed_container member since its members do.
|
||||
if same_id and isinstance(v0, DistributedVariable):
|
||||
return v0
|
||||
# * If v0 is a member of a distributed variable, in which case
|
||||
# hasattr(v0, "_distributed_container") is true, we want to
|
||||
# return the DistributedVariable that contains it using the
|
||||
# _distributed_container logic below. This case can trigger
|
||||
# same_id when there is only one device.
|
||||
# * In any other situation, same_id means we return v0.
|
||||
if same_id and (isinstance(v0, DistributedVariable) or
|
||||
not hasattr(v0, "_distributed_container")):
|
||||
# * In any other situation, same_id means we return v0 unless `always_wrap` is
|
||||
# true.
|
||||
if same_id and not always_wrap and not hasattr(v0, "_distributed_container"):
|
||||
return v0
|
||||
|
||||
# Detect the case where each device has a parallel component of the
|
||||
|
@ -23,7 +23,10 @@ import itertools
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.distribute import combinations
|
||||
from tensorflow.python.distribute import distribute_lib
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
@ -39,12 +42,14 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import indexed_slices
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
from tensorflow.python.saved_model.model_utils import mode_keys
|
||||
@ -54,7 +59,7 @@ from tensorflow.python.training.tracking import util as trackable_utils
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class DistributedValuesTest(test.TestCase):
|
||||
class DistributedValuesTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testGetEager(self):
|
||||
one = constant_op.constant(1)
|
||||
@ -73,6 +78,202 @@ class DistributedValuesTest(test.TestCase):
|
||||
with distribute_lib.ReplicaContext(None, 1):
|
||||
self.assertEqual(two, v._get())
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.all_strategies_minus_default,
|
||||
mode=["eager"]
|
||||
))
|
||||
def testMakeDistributedValueFromTensor(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
single_value = constant_op.constant(1)
|
||||
def value_fn(ctx):
|
||||
del ctx
|
||||
return single_value
|
||||
|
||||
distributed_values = (
|
||||
distribution.experimental_distribute_values_from_function(value_fn))
|
||||
self.assertAllEqual(
|
||||
distribution.experimental_local_results(distributed_values),
|
||||
constant_op.constant(1., shape=(distribution.num_replicas_in_sync)))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.all_strategies_minus_default,
|
||||
mode=["eager"]
|
||||
))
|
||||
def testMakeDistributedValueSingleNumpyArrayConstant(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
array_value = np.array([1., 2., 3.])
|
||||
def value_fn(ctx):
|
||||
del ctx
|
||||
return array_value
|
||||
|
||||
distributed_values = (
|
||||
distribution.experimental_distribute_values_from_function(value_fn))
|
||||
local_results = distribution.experimental_local_results(distributed_values)
|
||||
self.assertLen(local_results, distribution.num_replicas_in_sync)
|
||||
for result in local_results:
|
||||
self.assertAllEqual(result, [1., 2., 3.])
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.all_strategies_minus_default,
|
||||
mode=["eager"]
|
||||
))
|
||||
def testMakeDistributedValueTupleConstant(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
tuple_value = (1., 2., 3.)
|
||||
def value_fn(ctx):
|
||||
del ctx
|
||||
return tuple_value
|
||||
distributed_values = (
|
||||
distribution.experimental_distribute_values_from_function(value_fn))
|
||||
local_results = distribution.experimental_local_results(distributed_values)
|
||||
for result in local_results:
|
||||
self.assertAllEqual(result, (1., 2., 3.))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.all_strategies_minus_default,
|
||||
mode=["eager"]
|
||||
))
|
||||
def testMakeDistributedValueNestedStructurePerReplica(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
tuple_value = (1., 2., 3.)
|
||||
def value_fn(ctx):
|
||||
per_replica = []
|
||||
for val in tuple_value:
|
||||
per_replica.append(val * ctx.replica_id_in_sync_group)
|
||||
return per_replica
|
||||
distributed_values = (
|
||||
distribution.experimental_distribute_values_from_function(value_fn))
|
||||
for i in range(distribution.num_replicas_in_sync):
|
||||
self.assertAllEqual(
|
||||
values.select_replica(i, distributed_values),
|
||||
(1. * i, 2. * i, 3. * i))
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.all_strategies_minus_default,
|
||||
mode=["eager"]
|
||||
))
|
||||
def testMakeDistributedValueSpareTensor(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
def value_fn(ctx):
|
||||
del ctx
|
||||
return sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
|
||||
|
||||
distributed_values = (
|
||||
distribution.experimental_distribute_values_from_function(value_fn))
|
||||
local_results = distribution.experimental_local_results(distributed_values)
|
||||
for i in range(distribution.num_replicas_in_sync):
|
||||
self.assertAllEqual(
|
||||
sparse_ops.sparse_tensor_to_dense(local_results[i]),
|
||||
[[1, 0, 0, 0], [0, 0, 2, 0], [0, 0, 0, 0]])
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.all_strategies_minus_default,
|
||||
mode=["eager"]
|
||||
))
|
||||
def testMakeDistributedValueExtractFromArray(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
multiple_values = range(distribution.num_replicas_in_sync)
|
||||
def value_fn(ctx):
|
||||
return multiple_values[ctx.replica_id_in_sync_group]
|
||||
distributed_values = (
|
||||
distribution.experimental_distribute_values_from_function(value_fn))
|
||||
local_results = distribution.experimental_local_results(distributed_values)
|
||||
for i in range(distribution.num_replicas_in_sync):
|
||||
self.assertAllEqual(local_results[i], i)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=strategy_combinations.all_strategies_minus_default,
|
||||
mode=["eager"]
|
||||
))
|
||||
def testMakeDistributedValueAndRun(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
|
||||
@def_function.function
|
||||
def run():
|
||||
multiple_values = range(distribution.num_replicas_in_sync)
|
||||
def value_fn(ctx):
|
||||
return multiple_values[ctx.replica_id_in_sync_group]
|
||||
distributed_values = (
|
||||
distribution.experimental_distribute_values_from_function(value_fn))
|
||||
|
||||
def computation(x):
|
||||
return math_ops.square(x)
|
||||
|
||||
outputs = distribution.experimental_local_results(
|
||||
distribution.experimental_run_v2(computation,
|
||||
args=(distributed_values,)))
|
||||
return outputs
|
||||
|
||||
local_results = run()
|
||||
|
||||
for i in range(distribution.num_replicas_in_sync):
|
||||
self.assertAllEqual(local_results[i], i**2)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
mode=["eager"]
|
||||
))
|
||||
def testMakeDistributedValueDefaultDevicePlacement(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
multiple_values = []
|
||||
for i in range(distribution.num_replicas_in_sync):
|
||||
multiple_values.append(constant_op.constant(1.0))
|
||||
|
||||
def value_fn(ctx):
|
||||
return multiple_values[ctx.replica_id_in_sync_group]
|
||||
distributed_values = (
|
||||
distribution.experimental_distribute_values_from_function(value_fn))
|
||||
for i in range(distribution.num_replicas_in_sync):
|
||||
self.assertAllEqual(distributed_values._values[i].device,
|
||||
"/job:localhost/replica:0/task:0/device:CPU:0")
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
distribution=[
|
||||
strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
|
||||
strategy_combinations.tpu_strategy,
|
||||
strategy_combinations.central_storage_strategy_with_two_gpus,
|
||||
],
|
||||
mode=["eager"]
|
||||
))
|
||||
def testMakeDistributedValueExplicitDevicePlacement(self, distribution):
|
||||
if not tf2.enabled():
|
||||
self.skipTest("Only V2 is supported.")
|
||||
worker_devices = distribution.extended.worker_devices
|
||||
multiple_values = []
|
||||
for i in range(distribution.num_replicas_in_sync):
|
||||
with ops.device(worker_devices[i]):
|
||||
multiple_values.append(array_ops.identity(1.0))
|
||||
|
||||
def value_fn(ctx):
|
||||
return multiple_values[ctx.replica_id_in_sync_group]
|
||||
distributed_values = (
|
||||
distribution.experimental_distribute_values_from_function(value_fn))
|
||||
for i in range(distribution.num_replicas_in_sync):
|
||||
self.assertAllEqual(distributed_values._values[i].device,
|
||||
worker_devices[i])
|
||||
|
||||
|
||||
class DistributedDelegateTest(test.TestCase):
|
||||
|
||||
|
@ -36,6 +36,10 @@ tf_class {
|
||||
name: "experimental_distribute_datasets_from_function"
|
||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_values_from_function"
|
||||
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_local_results"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -36,6 +36,10 @@ tf_class {
|
||||
name: "experimental_distribute_datasets_from_function"
|
||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_values_from_function"
|
||||
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_local_results"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -35,6 +35,10 @@ tf_class {
|
||||
name: "experimental_distribute_datasets_from_function"
|
||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_values_from_function"
|
||||
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_local_results"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -36,6 +36,10 @@ tf_class {
|
||||
name: "experimental_distribute_datasets_from_function"
|
||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_values_from_function"
|
||||
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_local_results"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -36,6 +36,10 @@ tf_class {
|
||||
name: "experimental_distribute_datasets_from_function"
|
||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_values_from_function"
|
||||
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_local_results"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -36,6 +36,10 @@ tf_class {
|
||||
name: "experimental_distribute_datasets_from_function"
|
||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_values_from_function"
|
||||
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_local_results"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -36,6 +36,10 @@ tf_class {
|
||||
name: "experimental_distribute_datasets_from_function"
|
||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_distribute_values_from_function"
|
||||
argspec: "args=[\'self\', \'value_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_local_results"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -0,0 +1,17 @@
|
||||
path: "tensorflow.distribute.experimental.ValueContext"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.distribute.distribute_lib.ValueContext\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "num_replicas_in_sync"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "replica_id_in_sync_group"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'replica_id_in_sync_group\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'0\', \'1\'], "
|
||||
}
|
||||
}
|
@ -24,4 +24,8 @@ tf_module {
|
||||
name: "TPUStrategy"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "ValueContext"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user