Create tf.distribute namespace with the base classes and standard APIs

for distribution strategies. Strategy implementations will be in a
future change. Also rename DistributionStrategy to tf.distribute.Strategy,
and other changes to match.

RELNOTES: Expose tf.distribute.Strategy as the new name for
tf.contrib.distribute.DistributionStrategy.
PiperOrigin-RevId: 222097121
This commit is contained in:
A. Unique TensorFlower 2018-11-19 10:21:43 -08:00 committed by TensorFlower Gardener
parent 5c60fb7e9b
commit a4bcedd676
23 changed files with 837 additions and 126 deletions

View File

@ -216,7 +216,10 @@ py_library(
py_library(
name = "reduce_util",
srcs = ["reduce_util.py"],
deps = [],
deps = [
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
],
)
py_library(

View File

@ -21,9 +21,10 @@ from __future__ import print_function
import enum
from tensorflow.python.ops import variable_scope
from tensorflow.python.util.tf_export import tf_export
# TODO(priyag): Add this to tf.distribute namespace when it exists.
@tf_export("distribute.ReduceOp")
class ReduceOp(enum.Enum):
"""Indicates how a set of values should be reduced.

View File

@ -9,6 +9,7 @@ TENSORFLOW_API_INIT_FILES = [
"data/__init__.py",
"data/experimental/__init__.py",
"debugging/__init__.py",
"distribute/__init__.py",
"dtypes/__init__.py",
"errors/__init__.py",
"experimental/__init__.py",

View File

@ -10,6 +10,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"data/__init__.py",
"data/experimental/__init__.py",
"debugging/__init__.py",
"distribute/__init__.py",
"distributions/__init__.py",
"dtypes/__init__.py",
"errors/__init__.py",

View File

@ -35,10 +35,11 @@ DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)
_TENSORFLOW_DOC_SOURCES = {
'app': DocSource(docstring_module_name='platform.app'),
'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
'compat': DocSource(docstring_module_name='util.compat'),
'distribute': DocSource(docstring_module_name='training.distribute'),
'distributions': DocSource(
docstring_module_name='ops.distributions.distributions'),
'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
'errors': DocSource(docstring_module_name='framework.errors'),
'gfile': DocSource(docstring_module_name='platform.gfile'),
'graph_util': DocSource(docstring_module_name='framework.graph_util'),

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class DistributionStrategy, ReplicaContext, and supporting APIs."""
"""Library for running a computation across multiple devices."""
from __future__ import absolute_import
from __future__ import division
@ -38,19 +38,19 @@ from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls
# ------------------------------------------------------------------------------
# Context tracking whether in a distribution.update() or .update_non_slot()
# call.
# Context tracking whether in a strategy.update() or .update_non_slot() call.
_update_device = threading.local()
def get_update_device():
"""Get the current device if in a `DistributionStrategy.update()` call."""
"""Get the current device if in a `tf.distribute.Strategy.update()` call."""
try:
return _update_device.current
except AttributeError:
@ -77,8 +77,9 @@ class UpdateContext(object):
# Public utility functions.
@tf_export("distribute.get_loss_reduction")
def get_loss_reduction():
"""Reduce op corresponding to the last loss reduction."""
"""`tf.distribute.ReduceOp` corresponding to the last loss reduction."""
loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
if loss_reduction == losses_impl.Reduction.SUM:
return reduce_util.ReduceOp.SUM
@ -95,25 +96,25 @@ def _require_cross_replica_context_extended(extended):
cross_replica = context.cross_replica_context
if cross_replica is not None and cross_replica.extended is extended:
return
distribution_strategy = extended._container_strategy() # pylint: disable=protected-access
strategy = extended._container_strategy() # pylint: disable=protected-access
# We have an error to report, figure out the right message.
if context.distribution_strategy is not distribution_strategy:
_wrong_distribution_strategy_scope(distribution_strategy, context)
if context.distribution_strategy is not strategy:
_wrong_strategy_scope(strategy, context)
assert cross_replica is None
raise RuntimeError("Method requires being in cross-replica context, use "
"get_replica_context().merge_call()")
def _wrong_distribution_strategy_scope(distribution_strategy, context):
def _wrong_strategy_scope(strategy, context):
# Figure out the right error message.
if not distribution_strategy_context.has_distribution_strategy():
raise RuntimeError(
'Need to be inside "with distribution_strategy.scope()" for %s' %
(distribution_strategy,))
'Need to be inside "with strategy.scope()" for %s' %
(strategy,))
else:
raise RuntimeError(
"Mixing different DistributionStrategy objects: %s is not %s" %
(context.distribution_strategy, distribution_strategy))
"Mixing different tf.distribute.Strategy objects: %s is not %s" %
(context.distribution_strategy, strategy))
def require_replica_context(replica_ctx):
@ -124,18 +125,18 @@ def require_replica_context(replica_ctx):
if context.replica_context is None:
raise RuntimeError("Need to be inside `call_for_each_replica()`")
if context.distribution_strategy is replica_ctx.distribution_strategy:
# Two different ReplicaContexts with the same DistributionStrategy.
# Two different ReplicaContexts with the same tf.distribute.Strategy.
raise RuntimeError("Mismatching ReplicaContext.")
raise RuntimeError(
"Mismatching DistributionStrategy objects: %s is not %s." %
"Mismatching tf.distribute.Strategy objects: %s is not %s." %
(context.distribution_strategy, replica_ctx.distribution_strategy))
def _require_distribution_strategy_scope_strategy(distribution_strategy):
"""Verify in a `distribution_strategy.scope()` in this thread."""
def _require_distribution_strategy_scope_strategy(strategy):
"""Verify in a `strategy.scope()` in this thread."""
context = _get_per_thread_mode()
if context.distribution_strategy is distribution_strategy: return
_wrong_distribution_strategy_scope(distribution_strategy, context)
if context.distribution_strategy is strategy: return
_wrong_strategy_scope(strategy, context)
def _require_distribution_strategy_scope_extended(extended):
@ -143,8 +144,8 @@ def _require_distribution_strategy_scope_extended(extended):
context = _get_per_thread_mode()
if context.distribution_strategy.extended is extended: return
# Report error.
distribution_strategy = extended._container_strategy() # pylint: disable=protected-access
_wrong_distribution_strategy_scope(distribution_strategy, context)
strategy = extended._container_strategy() # pylint: disable=protected-access
_wrong_strategy_scope(strategy, context)
# ------------------------------------------------------------------------------
@ -153,15 +154,18 @@ def _require_distribution_strategy_scope_extended(extended):
class _CurrentDistributionContext(object):
"""Context manager for setting the `DistributionStrategy` and var creator."""
"""Context manager setting the current `tf.distribute.Strategy`.
Also: overrides the variable creator and optionally the current device.
"""
def __init__(self,
distribution_strategy,
strategy,
var_creator_scope,
var_scope=None,
default_device=None):
self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access
distribution_strategy)
strategy)
self._var_creator_scope = var_creator_scope
self._var_scope = var_scope
if default_device:
@ -190,8 +194,8 @@ class _CurrentDistributionContext(object):
class _SameScopeAgainContext(object):
"""Trivial context manager when you are already in `scope()`."""
def __init__(self, distribution_strategy):
self._distribution_strategy = distribution_strategy
def __init__(self, strategy):
self._distribution_strategy = strategy
def __enter__(self):
return self._distribution_strategy
@ -201,6 +205,7 @@ class _SameScopeAgainContext(object):
# TODO(yuefengz): add more replication modes.
@tf_export("distribute.InputReplicationMode")
class InputReplicationMode(enum.Enum):
"""Replication mode for input function."""
@ -211,6 +216,7 @@ class InputReplicationMode(enum.Enum):
PER_WORKER = "PER_WORKER"
@tf_export("distribute.InputContext")
class InputContext(object):
"""A class wrapping information needed by an input function.
@ -278,6 +284,7 @@ class InputContext(object):
# Base classes for all distribution strategies.
@tf_export("distribute.Strategy")
class DistributionStrategy(object):
"""A list of devices with a state & compute distribution policy.
@ -301,14 +308,14 @@ class DistributionStrategy(object):
@property
def extended(self):
"""`tf.contrib.distribute.DistributionStrategyExtended` with new methods."""
"""`tf.distribute.StrategyExtended` with additional methods."""
return self._extended
def scope(self):
"""Returns a context manager selecting this DistributionStrategy as current.
"""Returns a context manager selecting this Strategy as current.
Inside a `with distribution_strategy.scope():` code block, this thread
will use a variable creator set by `distribution_strategy`, and will
Inside a `with strategy.scope():` code block, this thread
will use a variable creator set by `strategy`, and will
enter its "cross-replica context".
Returns:
@ -330,20 +337,20 @@ class DistributionStrategy(object):
def distribute_dataset(self, dataset_fn):
"""Return a `dataset` split across all replicas. DEPRECATED.
DEPRECATED: Please use `make_dataset_iterator()` or
`make_input_fn_iterator()` instead.
DEPRECATED: Please use `make_dataset_iterator` or
`make_input_fn_iterator` instead.
Suitable for providing input to for `extended.call_for_each_replica()` by
Suitable for providing input to `extended.call_for_each_replica()` by
creating an iterator:
```
def dataset_fn():
return tf.data.Dataset.from_tensors([[1.]]).repeat()
with distribution_strategy.scope():
distributed_dataset = distribution_strategy.distribute_dataset(dataset_fn)
with strategy.scope():
distributed_dataset = strategy.distribute_dataset(dataset_fn)
iterator = distributed_dataset.make_initializable_iterator()
replica_results = distribution_strategy.extended.call_for_each_replica(
replica_results = strategy.extended.call_for_each_replica(
replica_fn, args=(iterator.get_next(),))
```
@ -374,8 +381,8 @@ class DistributionStrategy(object):
replicas.
Returns:
An `InputIterator` which returns inputs for each step of the computation.
User should call `initialize` on the returned iterator.
An `tf.distribute.InputIterator` which returns inputs for each step of the
computation. User should call `initialize` on the returned iterator.
"""
return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
@ -384,26 +391,26 @@ class DistributionStrategy(object):
replication_mode=InputReplicationMode.PER_WORKER):
"""Returns an iterator split across replicas created from an input function.
The `input_fn` should take an `InputContext` object where information about
input sharding can be accessed:
The `input_fn` should take an `tf.distribute.InputContext` object where
information about input sharding can be accessed:
```
def input_fn(input_context):
d = tf.data.Dataset.from_tensors([[1.]]).repeat()
return d.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
with distribution_strategy.scope():
iterator = distribution_strategy.make_input_fn_iterator(
with strategy.scope():
iterator = strategy.make_input_fn_iterator(
input_fn)
replica_results = distribution_strategy.call_for_each_replica(
replica_results = strategy.extended.call_for_each_replica(
replica_fn, iterator.get_next())
```
Args:
input_fn: A function that returns a `tf.data.Dataset`. This function is
expected to take an `InputContext` object.
replication_mode: an enum value of `InputReplicationMode`. Only
`PER_WORKER` is supported currently.
expected to take an `tf.distribute.InputContext` object.
replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
Only `PER_WORKER` is supported currently.
Returns:
An iterator object that can be initialized and fetched next element.
@ -544,7 +551,7 @@ class DistributionStrategy(object):
Args:
value: A value returned by `extended.call_for_each_replica()` or a
variable created in `scope()`.
variable created in `scope`.
Returns:
A list of values contained in `value`. If `value` represents a single
@ -638,17 +645,19 @@ class DistributionStrategy(object):
raise RuntimeError("Must only deepcopy DistributionStrategy.")
@tf_export("distribute.StrategyExtended")
class DistributionStrategyExtended(object):
"""Additional APIs for algorithms that need to be distribution-aware.
The intent is that you can write an algorithm in a stylized way and
it will be usable with a variety of different `DistributionStrategy`
it will be usable with a variety of different
`tf.distribute.Strategy`
implementations. Each descendant will implement a different strategy
for distributing the algorithm across multiple devices/machines.
Furthermore, these changes can be hidden inside the specific layers
and other library classes that need special treatment to run in a
distributed setting, so that most users' model definition code can
run unchanged. The `DistributionStrategy` API works the same way
run unchanged. The `tf.distribute.Strategy` API works the same way
with eager and graph execution.
First let's introduce a few high-level concepts:
@ -696,42 +705,33 @@ class DistributionStrategyExtended(object):
We have then a few approaches we want to support:
* Code written (as if) with no knowledge of class `DistributionStrategy`.
* Code written (as if) with no knowledge of class `tf.distribute.Strategy`.
This code should work as before, even if some of the layers, etc.
used by that code are written to be distribution-aware. This is done
by having a default `DistributionStrategy` that gives ordinary behavior,
by having a default `tf.distribute.Strategy` that gives ordinary behavior,
and by default being in a single replica context.
* Ordinary model code that you want to run using a specific
`DistributionStrategy`. This can be as simple as:
`tf.distribute.Strategy`. This can be as simple as:
```
with my_distribution.scope():
iterator = my_distribution.distribute_dataset(
dataset).make_one_shot_iterator()
replica_train_ops = my_distribution.extended.call_for_each_replica(
with my_strategy.scope():
iterator = my_strategy.make_dataset_iterator(dataset)
session.run(iterator.initialize())
replica_train_ops = my_strategy.extended.call_for_each_replica(
replica_fn, args=(iterator.get_next(),))
train_op = tf.group(my_distribution.unwrap(replica_train_ops))
train_op = my_strategy.group(replica_train_ops)
```
This takes an ordinary `dataset` and `replica_fn` and runs it
distributed using a particular `DistributionStrategy` in
`my_distribution`. Any variables created in `replica_fn` are created
using `my_distribution`'s policy, and library functions called by
distributed using a particular `tf.distribute.Strategy` in
`my_strategy`. Any variables created in `replica_fn` are created
using `my_strategy`'s policy, and library functions called by
`replica_fn` can use the `get_replica_context()` API to get enhanced
behavior in this case.
You can also create an initializable iterator instead of a one-shot
iterator. In that case, you will need to ensure that you initialize the
iterator before calling get_next.
```
iterator = my_distribution.distribute_dataset(
dataset).make_initializable_iterator())
session.run(iterator.initializer)
```
* If you want to write a distributed algorithm, you may use any of
the `DistributionStrategy` APIs inside a
`with my_distribution.scope():` block of code.
the `tf.distribute.Strategy` APIs inside a
`with my_strategy.scope():` block of code.
Lower-level concepts:
@ -758,7 +758,7 @@ class DistributionStrategyExtended(object):
* Replica context vs. Cross-replica context: _replica context_ is when we
are in some function that is being called once for each replica.
Otherwise we are in cross-replica context, which is useful for
calling `DistributionStrategy` methods which operate across the
calling `tf.distribute.Strategy` methods which operate across the
replicas (like `reduce_to()`). By default you start in a replica context
(the default "single replica context") and then some methods can
switch you back and forth, as described below.
@ -778,7 +778,7 @@ class DistributionStrategyExtended(object):
pick a consistent set of devices to pass to both
`colocate_vars_with()` and `update_non_slot()`.
When using a `DistributionStrategy`, we have a new type dimension
When using a `tf.distribute.Strategy`, we have a new type dimension
called _locality_ that says what values are compatible with which
APIs:
@ -856,19 +856,20 @@ class DistributionStrategyExtended(object):
input and destination.
Layers should expect to be called in a replica context, and can use
the `get_replica_context()` function to get a `ReplicaContext` object. The
the `tf.distribute.get_replica_context` function to get a
`tf.distribute.ReplicaContext` object. The
`ReplicaContext` object has a `merge_call()` method for entering
cross-replica context where you can use `reduce_to()` (or
`batch_reduce_to()`) and then optionally `update()` to update state.
You may use this API whether or not a `DistributionStrategy` is
You may use this API whether or not a `tf.distribute.Strategy` is
being used, since there is a default implementation of
`ReplicaContext` and `DistributionStrategy`.
`ReplicaContext` and `tf.distribute.Strategy`.
NOTE for new `DistributionStrategy` implementations: Please put all logic
in a subclass of `DistributionStrategyExtended`. The only code needed for
the `DistributionStrategy` subclass is for instantiating your subclass of
`DistributionStrategyExtended` in the `__init__` method.
NOTE for new `tf.distribute.Strategy` implementations: Please put all logic
in a subclass of `tf.distribute.StrategyExtended`. The only code needed for
the `tf.distribute.Strategy` subclass is for instantiating your subclass of
`tf.distribute.StrategyExtended` in the `__init__` method.
"""
def __init__(self, container_strategy):
@ -908,7 +909,7 @@ class DistributionStrategyExtended(object):
if kwargs.pop("partitioner", None) is not None:
tf_logging.log_first_n(
tf_logging.WARN, "Partitioned variables are disabled when using "
"current DistributionStrategy.", 1)
"current tf.distribute.Strategy.", 1)
return getter(*args, **kwargs)
return _CurrentDistributionContext(
@ -932,7 +933,7 @@ class DistributionStrategyExtended(object):
(read-only) value of any other variable.
Args:
v: A variable allocated within the scope of this `DistributionStrategy`.
v: A variable allocated within the scope of this `tf.distribute.Strategy`.
Returns:
A tensor representing the value of `v`, aggregated across replicas if
@ -953,9 +954,9 @@ class DistributionStrategyExtended(object):
Example usage:
```
with distribution_strategy.scope():
with strategy.scope():
var1 = tf.get_variable(...)
with distribution_strategy.colocate_vars_with(v1):
with strategy.extended.colocate_vars_with(v1):
# var2 and var3 will be created on the same device(s) as var1
var2 = tf.get_variable(...)
var3 = tf.get_variable(...)
@ -964,7 +965,7 @@ class DistributionStrategyExtended(object):
# operates on v1 from var1, v2 from var2, and v3 from var3
# `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
distribution_strategy.update(v1, fn, args=(v2, v3))
strategy.extended.update(v1, fn, args=(v2, v3))
```
Args:
@ -990,7 +991,7 @@ class DistributionStrategyExtended(object):
if not isinstance(result, dataset_ops.Dataset):
raise ValueError(
"dataset_fn() must return a tf.data.Dataset when using a "
"DistributionStrategy.")
"tf.distribute.Strategy.")
return result
# TODO(josh11b): `PerReplicaDataset` currently only implements a few methods of
@ -1320,7 +1321,7 @@ class DistributionStrategyExtended(object):
Args:
var_list: The list of variables being optimized, needed with the
default `DistributionStrategy`.
default `tf.distribute.Strategy`.
"""
raise NotImplementedError("must be implemented in descendants")
@ -1374,13 +1375,18 @@ class DistributionStrategyExtended(object):
# around their model creation and graph definition. There is no
# anticipated need to define descendants of _CurrentDistributionContext.
# It sets the current DistributionStrategy for purposes of
# `get_distribution_strategy()` and `has_distribution_strategy()`
# `get_strategy()` and `has_strategy()`
# and switches the thread mode to a "cross-replica context".
@tf_export("distribute.ReplicaContext")
class ReplicaContext(object):
"""DistributionStrategy API inside a `call_for_each_replica()` call."""
"""`tf.distribute.Strategy` API when in a replica context.
def __init__(self, distribution_strategy, replica_id_in_sync_group):
self._distribution_strategy = distribution_strategy
To be used inside your replicated step function, such as in a
`tf.distribute.StrategyExtended.call_for_each_replica` call.
"""
def __init__(self, strategy, replica_id_in_sync_group):
self._distribution_strategy = strategy
self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access
self)
self._replica_id_in_sync_group = replica_id_in_sync_group
@ -1396,23 +1402,24 @@ class ReplicaContext(object):
This allows communication and coordination when there are multiple calls
to a model function triggered by a call to
`distribution.call_for_each_replica(model_fn, ...)`.
`strategy.extended.call_for_each_replica(model_fn, ...)`.
See `MirroredDistribution.call_for_each_replica()` for an explanation.
See `tf.distribute.StrategyExtended.call_for_each_replica` for an
explanation.
Otherwise, this is equivalent to:
If not inside a distributed scope, this is equivalent to:
```
distribution = get_distribution_strategy()
with cross-replica-context(distribution):
return merge_fn(distribution, *args, **kwargs)
strategy = tf.distribute.get_strategy()
with cross-replica-context(strategy):
return merge_fn(strategy, *args, **kwargs)
```
Args:
merge_fn: function that joins arguments from threads that are given as
PerReplica. It accepts `DistributionStrategy` object as the first
argument.
args: List or tuple with positional per-thread arguments for `merge_fn`
PerReplica. It accepts `tf.distribute.Strategy` object as
the first argument.
args: List or tuple with positional per-thread arguments for `merge_fn`.
kwargs: Dict with keyword per-thread arguments for `merge_fn`.
Returns:
@ -1446,8 +1453,14 @@ class ReplicaContext(object):
return self._replica_id_in_sync_group
@property
@doc_controls.do_not_generate_docs # DEPRECATED, use `strategy`
def distribution_strategy(self):
"""The current `DistributionStrategy` object."""
"""DEPRECATED: use `self.stratgey` instead."""
return self._distribution_strategy
@property
def strategy(self):
"""The current `tf.distribute.Strategy` object."""
return self._distribution_strategy
@property
@ -1470,7 +1483,7 @@ class ReplicaContext(object):
class _DefaultDistributionStrategy(DistributionStrategy):
"""Default `DistributionStrategy` if none is explicitly selected."""
"""Default `tf.distribute.Strategy` if none is explicitly selected."""
def __init__(self):
super(_DefaultDistributionStrategy, self).__init__(
@ -1483,7 +1496,7 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
def _scope(self, strategy):
"""Context manager setting a variable creator and `self` as current."""
if distribution_strategy_context.has_distribution_strategy():
raise RuntimeError("Must not nest DistributionStrategy scopes.")
raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
def creator(next_creator, *args, **kwargs):
_require_distribution_strategy_scope_strategy(strategy)
@ -1555,13 +1568,13 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
@property
def worker_devices(self):
raise RuntimeError(
"worker_devices() method unsupported by _DefaultDistributionStrategy.")
raise RuntimeError("worker_devices() method unsupported by default "
"tf.distribute.Strategy.")
@property
def parameter_devices(self):
raise RuntimeError("parameter_devices() method unsupported by "
"_DefaultDistributionStrategy.")
raise RuntimeError("parameter_devices() method unsupported by default "
"tf.distribute.Strategy.")
def non_slot_devices(self, var_list):
return min(var_list, key=lambda x: x.name)
@ -1600,8 +1613,8 @@ _original_from_proto = resource_variable_ops._from_proto_fn
def _from_proto_fn(v, import_scope=None):
if distribution_strategy_context.has_distribution_strategy():
raise NotImplementedError(
"Deserialization of variables is not yet supported when using"
"distributed strategies.")
"Deserialization of variables is not yet supported when using a "
"tf.distribute.Strategy.")
else:
return _original_from_proto(v, import_scope=import_scope)

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
# There is a circular dependency between this and `distribute` module. So we
@ -85,6 +86,7 @@ def _get_per_thread_mode():
# Public API for accessing the current thread mode
@tf_export("distribute.get_replica_context")
def get_replica_context():
"""Returns the current `tf.distribute.ReplicaContext` or `None`.
@ -95,7 +97,7 @@ def get_replica_context():
1. starts in the default (single-replica) replica context (this function
will return the default `ReplicaContext` object);
2. switches to cross-replica context (in which case this will return
`None`) when entering a `with DistributionStrategy.scope():` block;
`None`) when entering a `with tf.distribute.Strategy.scope():` block;
3. switches to a (non-default) replica context inside
`extended.call_for_each_replica(fn, ...)`;
4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
@ -103,11 +105,11 @@ def get_replica_context():
this function will return `None`).
Note that you can also go directly from step 1 to 4 to switch to a
cross-replica context for the default `DistributionStrategy`. You may
cross-replica context for the default `tf.distribute.Strategy`. You may
also switch from the cross-replica context of 4 to a replica context by
calling `extended.call_for_each_replica()`, jumping back to step 3.
Most `DistributionStrategy` methods may only be executed in
Most `tf.distribute.Strategy` methods may only be executed in
a cross-replica context, in a replica context you should use the
`ReplicaContext` API instead.
@ -124,7 +126,7 @@ def get_replica_context():
def get_cross_replica_context():
"""Returns the current DistributionStrategy if in a cross-replica context.
"""Returns the current tf.distribute.Strategy if in a cross-replica context.
DEPRECATED: Please use `in_cross_replica_context()` and
`get_distribution_strategy()` instead.
@ -133,22 +135,22 @@ def get_cross_replica_context():
1. starts in the default (single-replica) replica context;
2. switches to cross-replica context when entering a
`with DistributionStrategy.scope():` block;
`with tf.distribute.Strategy.scope():` block;
3. switches to a (non-default) replica context inside
`call_for_each_replica(fn, ...)`;
4. if `fn` calls `get_replica_context()->merge_call(merge_fn, ...)`, then
inside `merge_fn` you are back in the cross-replica context.
Note that you can also go directly from step 1 to 4 to switch to a
cross-replica context for the default `DistributionStrategy`. You may
cross-replica context for the default `tf.distribute.Strategy`. You may
also switch from the cross-replica context of 4 to a replica context by
calling `call_for_each_replica()`, jumping back to step 3.
Most `DistributionStrategy` methods may only be executed in
Most `tf.distribute.Strategy` methods may only be executed in
a cross-replica context.
Returns:
Returns the current `DistributionStrategy` object in a cross-replica
Returns the current `tf.distribute.Strategy` object in a cross-replica
context, or `None`.
Exactly one of `get_replica_context()` and `get_cross_replica_context()`
@ -157,6 +159,7 @@ def get_cross_replica_context():
return _get_per_thread_mode().cross_replica_context
@tf_export("distribute.in_cross_replica_context")
def in_cross_replica_context():
"""Returns True if in a cross-replica context.
@ -170,31 +173,33 @@ def in_cross_replica_context():
return _get_per_thread_mode().cross_replica_context is not None
@tf_export("distribute.get_strategy")
def get_distribution_strategy():
"""Returns the current `DistributionStrategy` object.
"""Returns the current `tf.distribute.Strategy` object.
Typically only used in a cross-replica context:
```
if tf.distribute.in_cross_replica_context():
strategy = tf.distribute.get_distribution_strategy()
strategy = tf.distribute.get_strategy()
...
```
Returns:
A `DistributionStrategy` object. Inside a
A `tf.distribute.Strategy` object. Inside a
`with distribution_strategy.scope()` block, it returns
`distribution_strategy`, otherwise it returns the default
(single-replica) `DistributionStrategy` object.
(single-replica) `tf.distribute.Strategy` object.
"""
return _get_per_thread_mode().distribution_strategy
@tf_export("distribute.has_strategy")
def has_distribution_strategy():
"""Return if there is a current non-default `DistributionStrategy`.
"""Return if there is a current non-default `tf.distribute.Strategy`.
Returns:
True if inside a `with distribution_strategy.scope():`.
True if inside a `with strategy.scope():`.
"""
return get_distribution_strategy() is not _get_default_distribution_strategy()

View File

@ -0,0 +1,25 @@
path: "tensorflow.distribute.InputContext"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
is_instance: "<type \'object\'>"
member {
name: "input_pipeline_id"
mtype: "<type \'property\'>"
}
member {
name: "num_input_pipelines"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], "
}
member_method {
name: "get_per_replica_batch_size"
argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,8 @@
path: "tensorflow.distribute.InputReplicationMode"
tf_class {
is_instance: "<enum \'InputReplicationMode\'>"
member {
name: "PER_WORKER"
mtype: "<enum \'InputReplicationMode\'>"
}
}

View File

@ -0,0 +1,12 @@
path: "tensorflow.distribute.ReduceOp"
tf_class {
is_instance: "<enum \'ReduceOp\'>"
member {
name: "MEAN"
mtype: "<enum \'ReduceOp\'>"
}
member {
name: "SUM"
mtype: "<enum \'ReduceOp\'>"
}
}

View File

@ -0,0 +1,33 @@
path: "tensorflow.distribute.ReplicaContext"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
is_instance: "<type \'object\'>"
member {
name: "devices"
mtype: "<type \'property\'>"
}
member {
name: "distribution_strategy"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member {
name: "replica_id_in_sync_group"
mtype: "<type \'property\'>"
}
member {
name: "strategy"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "merge_call"
argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
}

View File

@ -0,0 +1,81 @@
path: "tensorflow.distribute.StrategyExtended"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_between_graph"
mtype: "<type \'property\'>"
}
member {
name: "experimental_require_static_shapes"
mtype: "<type \'property\'>"
}
member {
name: "experimental_should_init"
mtype: "<type \'property\'>"
}
member {
name: "parameter_devices"
mtype: "<type \'property\'>"
}
member {
name: "should_checkpoint"
mtype: "<type \'property\'>"
}
member {
name: "should_save_summary"
mtype: "<type \'property\'>"
}
member {
name: "worker_devices"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch_reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast_to"
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call_for_each_replica"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run_steps_on_iterator"
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "non_slot_devices"
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "read_var"
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update"
argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
}
member_method {
name: "update_non_slot"
argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
}
member_method {
name: "value_container"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,133 @@
path: "tensorflow.distribute.Strategy"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
is_instance: "<type \'object\'>"
member {
name: "between_graph"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member {
name: "parameter_devices"
mtype: "<type \'property\'>"
}
member {
name: "require_static_shapes"
mtype: "<type \'property\'>"
}
member {
name: "should_checkpoint"
mtype: "<type \'property\'>"
}
member {
name: "should_init"
mtype: "<type \'property\'>"
}
member {
name: "should_save_summary"
mtype: "<type \'property\'>"
}
member {
name: "worker_devices"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "call_for_each_replica"
argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "distribute_dataset"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_dataset_iterator"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_input_fn_iterator"
argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
}
member_method {
name: "non_slot_devices"
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "read_var"
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "run_steps_on_dataset"
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "unwrap"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update"
argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "update_non_slot"
argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "value_container"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,47 @@
path: "tensorflow.distribute"
tf_module {
member {
name: "InputContext"
mtype: "<type \'type\'>"
}
member {
name: "InputReplicationMode"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "ReduceOp"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "ReplicaContext"
mtype: "<type \'type\'>"
}
member {
name: "Strategy"
mtype: "<type \'type\'>"
}
member {
name: "StrategyExtended"
mtype: "<type \'type\'>"
}
member_method {
name: "get_loss_reduction"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_replica_context"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_strategy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "has_strategy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "in_cross_replica_context"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -324,6 +324,10 @@ tf_module {
name: "debugging"
mtype: "<type \'module\'>"
}
member {
name: "distribute"
mtype: "<type \'module\'>"
}
member {
name: "distributions"
mtype: "<type \'module\'>"

View File

@ -0,0 +1,25 @@
path: "tensorflow.distribute.InputContext"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
is_instance: "<type \'object\'>"
member {
name: "input_pipeline_id"
mtype: "<type \'property\'>"
}
member {
name: "num_input_pipelines"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], "
}
member_method {
name: "get_per_replica_batch_size"
argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,8 @@
path: "tensorflow.distribute.InputReplicationMode"
tf_class {
is_instance: "<enum \'InputReplicationMode\'>"
member {
name: "PER_WORKER"
mtype: "<enum \'InputReplicationMode\'>"
}
}

View File

@ -0,0 +1,12 @@
path: "tensorflow.distribute.ReduceOp"
tf_class {
is_instance: "<enum \'ReduceOp\'>"
member {
name: "MEAN"
mtype: "<enum \'ReduceOp\'>"
}
member {
name: "SUM"
mtype: "<enum \'ReduceOp\'>"
}
}

View File

@ -0,0 +1,33 @@
path: "tensorflow.distribute.ReplicaContext"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
is_instance: "<type \'object\'>"
member {
name: "devices"
mtype: "<type \'property\'>"
}
member {
name: "distribution_strategy"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member {
name: "replica_id_in_sync_group"
mtype: "<type \'property\'>"
}
member {
name: "strategy"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "merge_call"
argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
}

View File

@ -0,0 +1,81 @@
path: "tensorflow.distribute.StrategyExtended"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_between_graph"
mtype: "<type \'property\'>"
}
member {
name: "experimental_require_static_shapes"
mtype: "<type \'property\'>"
}
member {
name: "experimental_should_init"
mtype: "<type \'property\'>"
}
member {
name: "parameter_devices"
mtype: "<type \'property\'>"
}
member {
name: "should_checkpoint"
mtype: "<type \'property\'>"
}
member {
name: "should_save_summary"
mtype: "<type \'property\'>"
}
member {
name: "worker_devices"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch_reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast_to"
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call_for_each_replica"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run_steps_on_iterator"
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "non_slot_devices"
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "read_var"
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update"
argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
}
member_method {
name: "update_non_slot"
argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
}
member_method {
name: "value_container"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,133 @@
path: "tensorflow.distribute.Strategy"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
is_instance: "<type \'object\'>"
member {
name: "between_graph"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member {
name: "parameter_devices"
mtype: "<type \'property\'>"
}
member {
name: "require_static_shapes"
mtype: "<type \'property\'>"
}
member {
name: "should_checkpoint"
mtype: "<type \'property\'>"
}
member {
name: "should_init"
mtype: "<type \'property\'>"
}
member {
name: "should_save_summary"
mtype: "<type \'property\'>"
}
member {
name: "worker_devices"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "call_for_each_replica"
argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "distribute_dataset"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_dataset_iterator"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_input_fn_iterator"
argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
}
member_method {
name: "non_slot_devices"
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "read_var"
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "run_steps_on_dataset"
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "unwrap"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update"
argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "update_non_slot"
argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "value_container"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,47 @@
path: "tensorflow.distribute"
tf_module {
member {
name: "InputContext"
mtype: "<type \'type\'>"
}
member {
name: "InputReplicationMode"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "ReduceOp"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "ReplicaContext"
mtype: "<type \'type\'>"
}
member {
name: "Strategy"
mtype: "<type \'type\'>"
}
member {
name: "StrategyExtended"
mtype: "<type \'type\'>"
}
member_method {
name: "get_loss_reduction"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_replica_context"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_strategy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "has_strategy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "in_cross_replica_context"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -180,6 +180,10 @@ tf_module {
name: "debugging"
mtype: "<type \'module\'>"
}
member {
name: "distribute"
mtype: "<type \'module\'>"
}
member {
name: "double"
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"