Create tf.distribute namespace with the base classes and standard APIs

for distribution strategies. Strategy implementations will be in a
future change. Also rename DistributionStrategy to tf.distribute.Strategy,
and other changes to match.

RELNOTES: Expose tf.distribute.Strategy as the new name for
tf.contrib.distribute.DistributionStrategy.
PiperOrigin-RevId: 222097121
This commit is contained in:
A. Unique TensorFlower 2018-11-19 10:21:43 -08:00 committed by TensorFlower Gardener
parent 5c60fb7e9b
commit a4bcedd676
23 changed files with 837 additions and 126 deletions

View File

@ -216,7 +216,10 @@ py_library(
py_library( py_library(
name = "reduce_util", name = "reduce_util",
srcs = ["reduce_util.py"], srcs = ["reduce_util.py"],
deps = [], deps = [
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
],
) )
py_library( py_library(

View File

@ -21,9 +21,10 @@ from __future__ import print_function
import enum import enum
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.util.tf_export import tf_export
# TODO(priyag): Add this to tf.distribute namespace when it exists. @tf_export("distribute.ReduceOp")
class ReduceOp(enum.Enum): class ReduceOp(enum.Enum):
"""Indicates how a set of values should be reduced. """Indicates how a set of values should be reduced.

View File

@ -9,6 +9,7 @@ TENSORFLOW_API_INIT_FILES = [
"data/__init__.py", "data/__init__.py",
"data/experimental/__init__.py", "data/experimental/__init__.py",
"debugging/__init__.py", "debugging/__init__.py",
"distribute/__init__.py",
"dtypes/__init__.py", "dtypes/__init__.py",
"errors/__init__.py", "errors/__init__.py",
"experimental/__init__.py", "experimental/__init__.py",

View File

@ -10,6 +10,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"data/__init__.py", "data/__init__.py",
"data/experimental/__init__.py", "data/experimental/__init__.py",
"debugging/__init__.py", "debugging/__init__.py",
"distribute/__init__.py",
"distributions/__init__.py", "distributions/__init__.py",
"dtypes/__init__.py", "dtypes/__init__.py",
"errors/__init__.py", "errors/__init__.py",

View File

@ -35,10 +35,11 @@ DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)
_TENSORFLOW_DOC_SOURCES = { _TENSORFLOW_DOC_SOURCES = {
'app': DocSource(docstring_module_name='platform.app'), 'app': DocSource(docstring_module_name='platform.app'),
'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
'compat': DocSource(docstring_module_name='util.compat'), 'compat': DocSource(docstring_module_name='util.compat'),
'distribute': DocSource(docstring_module_name='training.distribute'),
'distributions': DocSource( 'distributions': DocSource(
docstring_module_name='ops.distributions.distributions'), docstring_module_name='ops.distributions.distributions'),
'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
'errors': DocSource(docstring_module_name='framework.errors'), 'errors': DocSource(docstring_module_name='framework.errors'),
'gfile': DocSource(docstring_module_name='platform.gfile'), 'gfile': DocSource(docstring_module_name='platform.gfile'),
'graph_util': DocSource(docstring_module_name='framework.graph_util'), 'graph_util': DocSource(docstring_module_name='framework.graph_util'),

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Class DistributionStrategy, ReplicaContext, and supporting APIs.""" """Library for running a computation across multiple devices."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -38,19 +38,19 @@ from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_util from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls from tensorflow.tools.docs import doc_controls
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
# Context tracking whether in a distribution.update() or .update_non_slot() # Context tracking whether in a strategy.update() or .update_non_slot() call.
# call.
_update_device = threading.local() _update_device = threading.local()
def get_update_device(): def get_update_device():
"""Get the current device if in a `DistributionStrategy.update()` call.""" """Get the current device if in a `tf.distribute.Strategy.update()` call."""
try: try:
return _update_device.current return _update_device.current
except AttributeError: except AttributeError:
@ -77,8 +77,9 @@ class UpdateContext(object):
# Public utility functions. # Public utility functions.
@tf_export("distribute.get_loss_reduction")
def get_loss_reduction(): def get_loss_reduction():
"""Reduce op corresponding to the last loss reduction.""" """`tf.distribute.ReduceOp` corresponding to the last loss reduction."""
loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
if loss_reduction == losses_impl.Reduction.SUM: if loss_reduction == losses_impl.Reduction.SUM:
return reduce_util.ReduceOp.SUM return reduce_util.ReduceOp.SUM
@ -95,25 +96,25 @@ def _require_cross_replica_context_extended(extended):
cross_replica = context.cross_replica_context cross_replica = context.cross_replica_context
if cross_replica is not None and cross_replica.extended is extended: if cross_replica is not None and cross_replica.extended is extended:
return return
distribution_strategy = extended._container_strategy() # pylint: disable=protected-access strategy = extended._container_strategy() # pylint: disable=protected-access
# We have an error to report, figure out the right message. # We have an error to report, figure out the right message.
if context.distribution_strategy is not distribution_strategy: if context.distribution_strategy is not strategy:
_wrong_distribution_strategy_scope(distribution_strategy, context) _wrong_strategy_scope(strategy, context)
assert cross_replica is None assert cross_replica is None
raise RuntimeError("Method requires being in cross-replica context, use " raise RuntimeError("Method requires being in cross-replica context, use "
"get_replica_context().merge_call()") "get_replica_context().merge_call()")
def _wrong_distribution_strategy_scope(distribution_strategy, context): def _wrong_strategy_scope(strategy, context):
# Figure out the right error message. # Figure out the right error message.
if not distribution_strategy_context.has_distribution_strategy(): if not distribution_strategy_context.has_distribution_strategy():
raise RuntimeError( raise RuntimeError(
'Need to be inside "with distribution_strategy.scope()" for %s' % 'Need to be inside "with strategy.scope()" for %s' %
(distribution_strategy,)) (strategy,))
else: else:
raise RuntimeError( raise RuntimeError(
"Mixing different DistributionStrategy objects: %s is not %s" % "Mixing different tf.distribute.Strategy objects: %s is not %s" %
(context.distribution_strategy, distribution_strategy)) (context.distribution_strategy, strategy))
def require_replica_context(replica_ctx): def require_replica_context(replica_ctx):
@ -124,18 +125,18 @@ def require_replica_context(replica_ctx):
if context.replica_context is None: if context.replica_context is None:
raise RuntimeError("Need to be inside `call_for_each_replica()`") raise RuntimeError("Need to be inside `call_for_each_replica()`")
if context.distribution_strategy is replica_ctx.distribution_strategy: if context.distribution_strategy is replica_ctx.distribution_strategy:
# Two different ReplicaContexts with the same DistributionStrategy. # Two different ReplicaContexts with the same tf.distribute.Strategy.
raise RuntimeError("Mismatching ReplicaContext.") raise RuntimeError("Mismatching ReplicaContext.")
raise RuntimeError( raise RuntimeError(
"Mismatching DistributionStrategy objects: %s is not %s." % "Mismatching tf.distribute.Strategy objects: %s is not %s." %
(context.distribution_strategy, replica_ctx.distribution_strategy)) (context.distribution_strategy, replica_ctx.distribution_strategy))
def _require_distribution_strategy_scope_strategy(distribution_strategy): def _require_distribution_strategy_scope_strategy(strategy):
"""Verify in a `distribution_strategy.scope()` in this thread.""" """Verify in a `strategy.scope()` in this thread."""
context = _get_per_thread_mode() context = _get_per_thread_mode()
if context.distribution_strategy is distribution_strategy: return if context.distribution_strategy is strategy: return
_wrong_distribution_strategy_scope(distribution_strategy, context) _wrong_strategy_scope(strategy, context)
def _require_distribution_strategy_scope_extended(extended): def _require_distribution_strategy_scope_extended(extended):
@ -143,8 +144,8 @@ def _require_distribution_strategy_scope_extended(extended):
context = _get_per_thread_mode() context = _get_per_thread_mode()
if context.distribution_strategy.extended is extended: return if context.distribution_strategy.extended is extended: return
# Report error. # Report error.
distribution_strategy = extended._container_strategy() # pylint: disable=protected-access strategy = extended._container_strategy() # pylint: disable=protected-access
_wrong_distribution_strategy_scope(distribution_strategy, context) _wrong_strategy_scope(strategy, context)
# ------------------------------------------------------------------------------ # ------------------------------------------------------------------------------
@ -153,15 +154,18 @@ def _require_distribution_strategy_scope_extended(extended):
class _CurrentDistributionContext(object): class _CurrentDistributionContext(object):
"""Context manager for setting the `DistributionStrategy` and var creator.""" """Context manager setting the current `tf.distribute.Strategy`.
Also: overrides the variable creator and optionally the current device.
"""
def __init__(self, def __init__(self,
distribution_strategy, strategy,
var_creator_scope, var_creator_scope,
var_scope=None, var_scope=None,
default_device=None): default_device=None):
self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access
distribution_strategy) strategy)
self._var_creator_scope = var_creator_scope self._var_creator_scope = var_creator_scope
self._var_scope = var_scope self._var_scope = var_scope
if default_device: if default_device:
@ -190,8 +194,8 @@ class _CurrentDistributionContext(object):
class _SameScopeAgainContext(object): class _SameScopeAgainContext(object):
"""Trivial context manager when you are already in `scope()`.""" """Trivial context manager when you are already in `scope()`."""
def __init__(self, distribution_strategy): def __init__(self, strategy):
self._distribution_strategy = distribution_strategy self._distribution_strategy = strategy
def __enter__(self): def __enter__(self):
return self._distribution_strategy return self._distribution_strategy
@ -201,6 +205,7 @@ class _SameScopeAgainContext(object):
# TODO(yuefengz): add more replication modes. # TODO(yuefengz): add more replication modes.
@tf_export("distribute.InputReplicationMode")
class InputReplicationMode(enum.Enum): class InputReplicationMode(enum.Enum):
"""Replication mode for input function.""" """Replication mode for input function."""
@ -211,6 +216,7 @@ class InputReplicationMode(enum.Enum):
PER_WORKER = "PER_WORKER" PER_WORKER = "PER_WORKER"
@tf_export("distribute.InputContext")
class InputContext(object): class InputContext(object):
"""A class wrapping information needed by an input function. """A class wrapping information needed by an input function.
@ -278,6 +284,7 @@ class InputContext(object):
# Base classes for all distribution strategies. # Base classes for all distribution strategies.
@tf_export("distribute.Strategy")
class DistributionStrategy(object): class DistributionStrategy(object):
"""A list of devices with a state & compute distribution policy. """A list of devices with a state & compute distribution policy.
@ -301,14 +308,14 @@ class DistributionStrategy(object):
@property @property
def extended(self): def extended(self):
"""`tf.contrib.distribute.DistributionStrategyExtended` with new methods.""" """`tf.distribute.StrategyExtended` with additional methods."""
return self._extended return self._extended
def scope(self): def scope(self):
"""Returns a context manager selecting this DistributionStrategy as current. """Returns a context manager selecting this Strategy as current.
Inside a `with distribution_strategy.scope():` code block, this thread Inside a `with strategy.scope():` code block, this thread
will use a variable creator set by `distribution_strategy`, and will will use a variable creator set by `strategy`, and will
enter its "cross-replica context". enter its "cross-replica context".
Returns: Returns:
@ -330,20 +337,20 @@ class DistributionStrategy(object):
def distribute_dataset(self, dataset_fn): def distribute_dataset(self, dataset_fn):
"""Return a `dataset` split across all replicas. DEPRECATED. """Return a `dataset` split across all replicas. DEPRECATED.
DEPRECATED: Please use `make_dataset_iterator()` or DEPRECATED: Please use `make_dataset_iterator` or
`make_input_fn_iterator()` instead. `make_input_fn_iterator` instead.
Suitable for providing input to for `extended.call_for_each_replica()` by Suitable for providing input to `extended.call_for_each_replica()` by
creating an iterator: creating an iterator:
``` ```
def dataset_fn(): def dataset_fn():
return tf.data.Dataset.from_tensors([[1.]]).repeat() return tf.data.Dataset.from_tensors([[1.]]).repeat()
with distribution_strategy.scope(): with strategy.scope():
distributed_dataset = distribution_strategy.distribute_dataset(dataset_fn) distributed_dataset = strategy.distribute_dataset(dataset_fn)
iterator = distributed_dataset.make_initializable_iterator() iterator = distributed_dataset.make_initializable_iterator()
replica_results = distribution_strategy.extended.call_for_each_replica( replica_results = strategy.extended.call_for_each_replica(
replica_fn, args=(iterator.get_next(),)) replica_fn, args=(iterator.get_next(),))
``` ```
@ -374,8 +381,8 @@ class DistributionStrategy(object):
replicas. replicas.
Returns: Returns:
An `InputIterator` which returns inputs for each step of the computation. An `tf.distribute.InputIterator` which returns inputs for each step of the
User should call `initialize` on the returned iterator. computation. User should call `initialize` on the returned iterator.
""" """
return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
@ -384,26 +391,26 @@ class DistributionStrategy(object):
replication_mode=InputReplicationMode.PER_WORKER): replication_mode=InputReplicationMode.PER_WORKER):
"""Returns an iterator split across replicas created from an input function. """Returns an iterator split across replicas created from an input function.
The `input_fn` should take an `InputContext` object where information about The `input_fn` should take an `tf.distribute.InputContext` object where
input sharding can be accessed: information about input sharding can be accessed:
``` ```
def input_fn(input_context): def input_fn(input_context):
d = tf.data.Dataset.from_tensors([[1.]]).repeat() d = tf.data.Dataset.from_tensors([[1.]]).repeat()
return d.shard(input_context.num_input_pipelines, return d.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id) input_context.input_pipeline_id)
with distribution_strategy.scope(): with strategy.scope():
iterator = distribution_strategy.make_input_fn_iterator( iterator = strategy.make_input_fn_iterator(
input_fn) input_fn)
replica_results = distribution_strategy.call_for_each_replica( replica_results = strategy.extended.call_for_each_replica(
replica_fn, iterator.get_next()) replica_fn, iterator.get_next())
``` ```
Args: Args:
input_fn: A function that returns a `tf.data.Dataset`. This function is input_fn: A function that returns a `tf.data.Dataset`. This function is
expected to take an `InputContext` object. expected to take an `tf.distribute.InputContext` object.
replication_mode: an enum value of `InputReplicationMode`. Only replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
`PER_WORKER` is supported currently. Only `PER_WORKER` is supported currently.
Returns: Returns:
An iterator object that can be initialized and fetched next element. An iterator object that can be initialized and fetched next element.
@ -544,7 +551,7 @@ class DistributionStrategy(object):
Args: Args:
value: A value returned by `extended.call_for_each_replica()` or a value: A value returned by `extended.call_for_each_replica()` or a
variable created in `scope()`. variable created in `scope`.
Returns: Returns:
A list of values contained in `value`. If `value` represents a single A list of values contained in `value`. If `value` represents a single
@ -638,17 +645,19 @@ class DistributionStrategy(object):
raise RuntimeError("Must only deepcopy DistributionStrategy.") raise RuntimeError("Must only deepcopy DistributionStrategy.")
@tf_export("distribute.StrategyExtended")
class DistributionStrategyExtended(object): class DistributionStrategyExtended(object):
"""Additional APIs for algorithms that need to be distribution-aware. """Additional APIs for algorithms that need to be distribution-aware.
The intent is that you can write an algorithm in a stylized way and The intent is that you can write an algorithm in a stylized way and
it will be usable with a variety of different `DistributionStrategy` it will be usable with a variety of different
`tf.distribute.Strategy`
implementations. Each descendant will implement a different strategy implementations. Each descendant will implement a different strategy
for distributing the algorithm across multiple devices/machines. for distributing the algorithm across multiple devices/machines.
Furthermore, these changes can be hidden inside the specific layers Furthermore, these changes can be hidden inside the specific layers
and other library classes that need special treatment to run in a and other library classes that need special treatment to run in a
distributed setting, so that most users' model definition code can distributed setting, so that most users' model definition code can
run unchanged. The `DistributionStrategy` API works the same way run unchanged. The `tf.distribute.Strategy` API works the same way
with eager and graph execution. with eager and graph execution.
First let's introduce a few high-level concepts: First let's introduce a few high-level concepts:
@ -696,42 +705,33 @@ class DistributionStrategyExtended(object):
We have then a few approaches we want to support: We have then a few approaches we want to support:
* Code written (as if) with no knowledge of class `DistributionStrategy`. * Code written (as if) with no knowledge of class `tf.distribute.Strategy`.
This code should work as before, even if some of the layers, etc. This code should work as before, even if some of the layers, etc.
used by that code are written to be distribution-aware. This is done used by that code are written to be distribution-aware. This is done
by having a default `DistributionStrategy` that gives ordinary behavior, by having a default `tf.distribute.Strategy` that gives ordinary behavior,
and by default being in a single replica context. and by default being in a single replica context.
* Ordinary model code that you want to run using a specific * Ordinary model code that you want to run using a specific
`DistributionStrategy`. This can be as simple as: `tf.distribute.Strategy`. This can be as simple as:
``` ```
with my_distribution.scope(): with my_strategy.scope():
iterator = my_distribution.distribute_dataset( iterator = my_strategy.make_dataset_iterator(dataset)
dataset).make_one_shot_iterator() session.run(iterator.initialize())
replica_train_ops = my_distribution.extended.call_for_each_replica( replica_train_ops = my_strategy.extended.call_for_each_replica(
replica_fn, args=(iterator.get_next(),)) replica_fn, args=(iterator.get_next(),))
train_op = tf.group(my_distribution.unwrap(replica_train_ops)) train_op = my_strategy.group(replica_train_ops)
``` ```
This takes an ordinary `dataset` and `replica_fn` and runs it This takes an ordinary `dataset` and `replica_fn` and runs it
distributed using a particular `DistributionStrategy` in distributed using a particular `tf.distribute.Strategy` in
`my_distribution`. Any variables created in `replica_fn` are created `my_strategy`. Any variables created in `replica_fn` are created
using `my_distribution`'s policy, and library functions called by using `my_strategy`'s policy, and library functions called by
`replica_fn` can use the `get_replica_context()` API to get enhanced `replica_fn` can use the `get_replica_context()` API to get enhanced
behavior in this case. behavior in this case.
You can also create an initializable iterator instead of a one-shot
iterator. In that case, you will need to ensure that you initialize the
iterator before calling get_next.
```
iterator = my_distribution.distribute_dataset(
dataset).make_initializable_iterator())
session.run(iterator.initializer)
```
* If you want to write a distributed algorithm, you may use any of * If you want to write a distributed algorithm, you may use any of
the `DistributionStrategy` APIs inside a the `tf.distribute.Strategy` APIs inside a
`with my_distribution.scope():` block of code. `with my_strategy.scope():` block of code.
Lower-level concepts: Lower-level concepts:
@ -758,7 +758,7 @@ class DistributionStrategyExtended(object):
* Replica context vs. Cross-replica context: _replica context_ is when we * Replica context vs. Cross-replica context: _replica context_ is when we
are in some function that is being called once for each replica. are in some function that is being called once for each replica.
Otherwise we are in cross-replica context, which is useful for Otherwise we are in cross-replica context, which is useful for
calling `DistributionStrategy` methods which operate across the calling `tf.distribute.Strategy` methods which operate across the
replicas (like `reduce_to()`). By default you start in a replica context replicas (like `reduce_to()`). By default you start in a replica context
(the default "single replica context") and then some methods can (the default "single replica context") and then some methods can
switch you back and forth, as described below. switch you back and forth, as described below.
@ -778,7 +778,7 @@ class DistributionStrategyExtended(object):
pick a consistent set of devices to pass to both pick a consistent set of devices to pass to both
`colocate_vars_with()` and `update_non_slot()`. `colocate_vars_with()` and `update_non_slot()`.
When using a `DistributionStrategy`, we have a new type dimension When using a `tf.distribute.Strategy`, we have a new type dimension
called _locality_ that says what values are compatible with which called _locality_ that says what values are compatible with which
APIs: APIs:
@ -856,19 +856,20 @@ class DistributionStrategyExtended(object):
input and destination. input and destination.
Layers should expect to be called in a replica context, and can use Layers should expect to be called in a replica context, and can use
the `get_replica_context()` function to get a `ReplicaContext` object. The the `tf.distribute.get_replica_context` function to get a
`tf.distribute.ReplicaContext` object. The
`ReplicaContext` object has a `merge_call()` method for entering `ReplicaContext` object has a `merge_call()` method for entering
cross-replica context where you can use `reduce_to()` (or cross-replica context where you can use `reduce_to()` (or
`batch_reduce_to()`) and then optionally `update()` to update state. `batch_reduce_to()`) and then optionally `update()` to update state.
You may use this API whether or not a `DistributionStrategy` is You may use this API whether or not a `tf.distribute.Strategy` is
being used, since there is a default implementation of being used, since there is a default implementation of
`ReplicaContext` and `DistributionStrategy`. `ReplicaContext` and `tf.distribute.Strategy`.
NOTE for new `DistributionStrategy` implementations: Please put all logic NOTE for new `tf.distribute.Strategy` implementations: Please put all logic
in a subclass of `DistributionStrategyExtended`. The only code needed for in a subclass of `tf.distribute.StrategyExtended`. The only code needed for
the `DistributionStrategy` subclass is for instantiating your subclass of the `tf.distribute.Strategy` subclass is for instantiating your subclass of
`DistributionStrategyExtended` in the `__init__` method. `tf.distribute.StrategyExtended` in the `__init__` method.
""" """
def __init__(self, container_strategy): def __init__(self, container_strategy):
@ -908,7 +909,7 @@ class DistributionStrategyExtended(object):
if kwargs.pop("partitioner", None) is not None: if kwargs.pop("partitioner", None) is not None:
tf_logging.log_first_n( tf_logging.log_first_n(
tf_logging.WARN, "Partitioned variables are disabled when using " tf_logging.WARN, "Partitioned variables are disabled when using "
"current DistributionStrategy.", 1) "current tf.distribute.Strategy.", 1)
return getter(*args, **kwargs) return getter(*args, **kwargs)
return _CurrentDistributionContext( return _CurrentDistributionContext(
@ -932,7 +933,7 @@ class DistributionStrategyExtended(object):
(read-only) value of any other variable. (read-only) value of any other variable.
Args: Args:
v: A variable allocated within the scope of this `DistributionStrategy`. v: A variable allocated within the scope of this `tf.distribute.Strategy`.
Returns: Returns:
A tensor representing the value of `v`, aggregated across replicas if A tensor representing the value of `v`, aggregated across replicas if
@ -953,9 +954,9 @@ class DistributionStrategyExtended(object):
Example usage: Example usage:
``` ```
with distribution_strategy.scope(): with strategy.scope():
var1 = tf.get_variable(...) var1 = tf.get_variable(...)
with distribution_strategy.colocate_vars_with(v1): with strategy.extended.colocate_vars_with(v1):
# var2 and var3 will be created on the same device(s) as var1 # var2 and var3 will be created on the same device(s) as var1
var2 = tf.get_variable(...) var2 = tf.get_variable(...)
var3 = tf.get_variable(...) var3 = tf.get_variable(...)
@ -964,7 +965,7 @@ class DistributionStrategyExtended(object):
# operates on v1 from var1, v2 from var2, and v3 from var3 # operates on v1 from var1, v2 from var2, and v3 from var3
# `fn` runs on every device `v1` is on, `v2` and `v3` will be there too. # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
distribution_strategy.update(v1, fn, args=(v2, v3)) strategy.extended.update(v1, fn, args=(v2, v3))
``` ```
Args: Args:
@ -990,7 +991,7 @@ class DistributionStrategyExtended(object):
if not isinstance(result, dataset_ops.Dataset): if not isinstance(result, dataset_ops.Dataset):
raise ValueError( raise ValueError(
"dataset_fn() must return a tf.data.Dataset when using a " "dataset_fn() must return a tf.data.Dataset when using a "
"DistributionStrategy.") "tf.distribute.Strategy.")
return result return result
# TODO(josh11b): `PerReplicaDataset` currently only implements a few methods of # TODO(josh11b): `PerReplicaDataset` currently only implements a few methods of
@ -1320,7 +1321,7 @@ class DistributionStrategyExtended(object):
Args: Args:
var_list: The list of variables being optimized, needed with the var_list: The list of variables being optimized, needed with the
default `DistributionStrategy`. default `tf.distribute.Strategy`.
""" """
raise NotImplementedError("must be implemented in descendants") raise NotImplementedError("must be implemented in descendants")
@ -1374,13 +1375,18 @@ class DistributionStrategyExtended(object):
# around their model creation and graph definition. There is no # around their model creation and graph definition. There is no
# anticipated need to define descendants of _CurrentDistributionContext. # anticipated need to define descendants of _CurrentDistributionContext.
# It sets the current DistributionStrategy for purposes of # It sets the current DistributionStrategy for purposes of
# `get_distribution_strategy()` and `has_distribution_strategy()` # `get_strategy()` and `has_strategy()`
# and switches the thread mode to a "cross-replica context". # and switches the thread mode to a "cross-replica context".
@tf_export("distribute.ReplicaContext")
class ReplicaContext(object): class ReplicaContext(object):
"""DistributionStrategy API inside a `call_for_each_replica()` call.""" """`tf.distribute.Strategy` API when in a replica context.
def __init__(self, distribution_strategy, replica_id_in_sync_group): To be used inside your replicated step function, such as in a
self._distribution_strategy = distribution_strategy `tf.distribute.StrategyExtended.call_for_each_replica` call.
"""
def __init__(self, strategy, replica_id_in_sync_group):
self._distribution_strategy = strategy
self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access
self) self)
self._replica_id_in_sync_group = replica_id_in_sync_group self._replica_id_in_sync_group = replica_id_in_sync_group
@ -1396,23 +1402,24 @@ class ReplicaContext(object):
This allows communication and coordination when there are multiple calls This allows communication and coordination when there are multiple calls
to a model function triggered by a call to to a model function triggered by a call to
`distribution.call_for_each_replica(model_fn, ...)`. `strategy.extended.call_for_each_replica(model_fn, ...)`.
See `MirroredDistribution.call_for_each_replica()` for an explanation. See `tf.distribute.StrategyExtended.call_for_each_replica` for an
explanation.
Otherwise, this is equivalent to: If not inside a distributed scope, this is equivalent to:
``` ```
distribution = get_distribution_strategy() strategy = tf.distribute.get_strategy()
with cross-replica-context(distribution): with cross-replica-context(strategy):
return merge_fn(distribution, *args, **kwargs) return merge_fn(strategy, *args, **kwargs)
``` ```
Args: Args:
merge_fn: function that joins arguments from threads that are given as merge_fn: function that joins arguments from threads that are given as
PerReplica. It accepts `DistributionStrategy` object as the first PerReplica. It accepts `tf.distribute.Strategy` object as
argument. the first argument.
args: List or tuple with positional per-thread arguments for `merge_fn` args: List or tuple with positional per-thread arguments for `merge_fn`.
kwargs: Dict with keyword per-thread arguments for `merge_fn`. kwargs: Dict with keyword per-thread arguments for `merge_fn`.
Returns: Returns:
@ -1446,8 +1453,14 @@ class ReplicaContext(object):
return self._replica_id_in_sync_group return self._replica_id_in_sync_group
@property @property
@doc_controls.do_not_generate_docs # DEPRECATED, use `strategy`
def distribution_strategy(self): def distribution_strategy(self):
"""The current `DistributionStrategy` object.""" """DEPRECATED: use `self.stratgey` instead."""
return self._distribution_strategy
@property
def strategy(self):
"""The current `tf.distribute.Strategy` object."""
return self._distribution_strategy return self._distribution_strategy
@property @property
@ -1470,7 +1483,7 @@ class ReplicaContext(object):
class _DefaultDistributionStrategy(DistributionStrategy): class _DefaultDistributionStrategy(DistributionStrategy):
"""Default `DistributionStrategy` if none is explicitly selected.""" """Default `tf.distribute.Strategy` if none is explicitly selected."""
def __init__(self): def __init__(self):
super(_DefaultDistributionStrategy, self).__init__( super(_DefaultDistributionStrategy, self).__init__(
@ -1483,7 +1496,7 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
def _scope(self, strategy): def _scope(self, strategy):
"""Context manager setting a variable creator and `self` as current.""" """Context manager setting a variable creator and `self` as current."""
if distribution_strategy_context.has_distribution_strategy(): if distribution_strategy_context.has_distribution_strategy():
raise RuntimeError("Must not nest DistributionStrategy scopes.") raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
def creator(next_creator, *args, **kwargs): def creator(next_creator, *args, **kwargs):
_require_distribution_strategy_scope_strategy(strategy) _require_distribution_strategy_scope_strategy(strategy)
@ -1555,13 +1568,13 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
@property @property
def worker_devices(self): def worker_devices(self):
raise RuntimeError( raise RuntimeError("worker_devices() method unsupported by default "
"worker_devices() method unsupported by _DefaultDistributionStrategy.") "tf.distribute.Strategy.")
@property @property
def parameter_devices(self): def parameter_devices(self):
raise RuntimeError("parameter_devices() method unsupported by " raise RuntimeError("parameter_devices() method unsupported by default "
"_DefaultDistributionStrategy.") "tf.distribute.Strategy.")
def non_slot_devices(self, var_list): def non_slot_devices(self, var_list):
return min(var_list, key=lambda x: x.name) return min(var_list, key=lambda x: x.name)
@ -1600,8 +1613,8 @@ _original_from_proto = resource_variable_ops._from_proto_fn
def _from_proto_fn(v, import_scope=None): def _from_proto_fn(v, import_scope=None):
if distribution_strategy_context.has_distribution_strategy(): if distribution_strategy_context.has_distribution_strategy():
raise NotImplementedError( raise NotImplementedError(
"Deserialization of variables is not yet supported when using" "Deserialization of variables is not yet supported when using a "
"distributed strategies.") "tf.distribute.Strategy.")
else: else:
return _original_from_proto(v, import_scope=import_scope) return _original_from_proto(v, import_scope=import_scope)

View File

@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.util.lazy_loader import LazyLoader from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
# There is a circular dependency between this and `distribute` module. So we # There is a circular dependency between this and `distribute` module. So we
@ -85,6 +86,7 @@ def _get_per_thread_mode():
# Public API for accessing the current thread mode # Public API for accessing the current thread mode
@tf_export("distribute.get_replica_context")
def get_replica_context(): def get_replica_context():
"""Returns the current `tf.distribute.ReplicaContext` or `None`. """Returns the current `tf.distribute.ReplicaContext` or `None`.
@ -95,7 +97,7 @@ def get_replica_context():
1. starts in the default (single-replica) replica context (this function 1. starts in the default (single-replica) replica context (this function
will return the default `ReplicaContext` object); will return the default `ReplicaContext` object);
2. switches to cross-replica context (in which case this will return 2. switches to cross-replica context (in which case this will return
`None`) when entering a `with DistributionStrategy.scope():` block; `None`) when entering a `with tf.distribute.Strategy.scope():` block;
3. switches to a (non-default) replica context inside 3. switches to a (non-default) replica context inside
`extended.call_for_each_replica(fn, ...)`; `extended.call_for_each_replica(fn, ...)`;
4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then 4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
@ -103,11 +105,11 @@ def get_replica_context():
this function will return `None`). this function will return `None`).
Note that you can also go directly from step 1 to 4 to switch to a Note that you can also go directly from step 1 to 4 to switch to a
cross-replica context for the default `DistributionStrategy`. You may cross-replica context for the default `tf.distribute.Strategy`. You may
also switch from the cross-replica context of 4 to a replica context by also switch from the cross-replica context of 4 to a replica context by
calling `extended.call_for_each_replica()`, jumping back to step 3. calling `extended.call_for_each_replica()`, jumping back to step 3.
Most `DistributionStrategy` methods may only be executed in Most `tf.distribute.Strategy` methods may only be executed in
a cross-replica context, in a replica context you should use the a cross-replica context, in a replica context you should use the
`ReplicaContext` API instead. `ReplicaContext` API instead.
@ -124,7 +126,7 @@ def get_replica_context():
def get_cross_replica_context(): def get_cross_replica_context():
"""Returns the current DistributionStrategy if in a cross-replica context. """Returns the current tf.distribute.Strategy if in a cross-replica context.
DEPRECATED: Please use `in_cross_replica_context()` and DEPRECATED: Please use `in_cross_replica_context()` and
`get_distribution_strategy()` instead. `get_distribution_strategy()` instead.
@ -133,22 +135,22 @@ def get_cross_replica_context():
1. starts in the default (single-replica) replica context; 1. starts in the default (single-replica) replica context;
2. switches to cross-replica context when entering a 2. switches to cross-replica context when entering a
`with DistributionStrategy.scope():` block; `with tf.distribute.Strategy.scope():` block;
3. switches to a (non-default) replica context inside 3. switches to a (non-default) replica context inside
`call_for_each_replica(fn, ...)`; `call_for_each_replica(fn, ...)`;
4. if `fn` calls `get_replica_context()->merge_call(merge_fn, ...)`, then 4. if `fn` calls `get_replica_context()->merge_call(merge_fn, ...)`, then
inside `merge_fn` you are back in the cross-replica context. inside `merge_fn` you are back in the cross-replica context.
Note that you can also go directly from step 1 to 4 to switch to a Note that you can also go directly from step 1 to 4 to switch to a
cross-replica context for the default `DistributionStrategy`. You may cross-replica context for the default `tf.distribute.Strategy`. You may
also switch from the cross-replica context of 4 to a replica context by also switch from the cross-replica context of 4 to a replica context by
calling `call_for_each_replica()`, jumping back to step 3. calling `call_for_each_replica()`, jumping back to step 3.
Most `DistributionStrategy` methods may only be executed in Most `tf.distribute.Strategy` methods may only be executed in
a cross-replica context. a cross-replica context.
Returns: Returns:
Returns the current `DistributionStrategy` object in a cross-replica Returns the current `tf.distribute.Strategy` object in a cross-replica
context, or `None`. context, or `None`.
Exactly one of `get_replica_context()` and `get_cross_replica_context()` Exactly one of `get_replica_context()` and `get_cross_replica_context()`
@ -157,6 +159,7 @@ def get_cross_replica_context():
return _get_per_thread_mode().cross_replica_context return _get_per_thread_mode().cross_replica_context
@tf_export("distribute.in_cross_replica_context")
def in_cross_replica_context(): def in_cross_replica_context():
"""Returns True if in a cross-replica context. """Returns True if in a cross-replica context.
@ -170,31 +173,33 @@ def in_cross_replica_context():
return _get_per_thread_mode().cross_replica_context is not None return _get_per_thread_mode().cross_replica_context is not None
@tf_export("distribute.get_strategy")
def get_distribution_strategy(): def get_distribution_strategy():
"""Returns the current `DistributionStrategy` object. """Returns the current `tf.distribute.Strategy` object.
Typically only used in a cross-replica context: Typically only used in a cross-replica context:
``` ```
if tf.distribute.in_cross_replica_context(): if tf.distribute.in_cross_replica_context():
strategy = tf.distribute.get_distribution_strategy() strategy = tf.distribute.get_strategy()
... ...
``` ```
Returns: Returns:
A `DistributionStrategy` object. Inside a A `tf.distribute.Strategy` object. Inside a
`with distribution_strategy.scope()` block, it returns `with distribution_strategy.scope()` block, it returns
`distribution_strategy`, otherwise it returns the default `distribution_strategy`, otherwise it returns the default
(single-replica) `DistributionStrategy` object. (single-replica) `tf.distribute.Strategy` object.
""" """
return _get_per_thread_mode().distribution_strategy return _get_per_thread_mode().distribution_strategy
@tf_export("distribute.has_strategy")
def has_distribution_strategy(): def has_distribution_strategy():
"""Return if there is a current non-default `DistributionStrategy`. """Return if there is a current non-default `tf.distribute.Strategy`.
Returns: Returns:
True if inside a `with distribution_strategy.scope():`. True if inside a `with strategy.scope():`.
""" """
return get_distribution_strategy() is not _get_default_distribution_strategy() return get_distribution_strategy() is not _get_default_distribution_strategy()

View File

@ -0,0 +1,25 @@
path: "tensorflow.distribute.InputContext"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
is_instance: "<type \'object\'>"
member {
name: "input_pipeline_id"
mtype: "<type \'property\'>"
}
member {
name: "num_input_pipelines"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], "
}
member_method {
name: "get_per_replica_batch_size"
argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,8 @@
path: "tensorflow.distribute.InputReplicationMode"
tf_class {
is_instance: "<enum \'InputReplicationMode\'>"
member {
name: "PER_WORKER"
mtype: "<enum \'InputReplicationMode\'>"
}
}

View File

@ -0,0 +1,12 @@
path: "tensorflow.distribute.ReduceOp"
tf_class {
is_instance: "<enum \'ReduceOp\'>"
member {
name: "MEAN"
mtype: "<enum \'ReduceOp\'>"
}
member {
name: "SUM"
mtype: "<enum \'ReduceOp\'>"
}
}

View File

@ -0,0 +1,33 @@
path: "tensorflow.distribute.ReplicaContext"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
is_instance: "<type \'object\'>"
member {
name: "devices"
mtype: "<type \'property\'>"
}
member {
name: "distribution_strategy"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member {
name: "replica_id_in_sync_group"
mtype: "<type \'property\'>"
}
member {
name: "strategy"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "merge_call"
argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
}

View File

@ -0,0 +1,81 @@
path: "tensorflow.distribute.StrategyExtended"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_between_graph"
mtype: "<type \'property\'>"
}
member {
name: "experimental_require_static_shapes"
mtype: "<type \'property\'>"
}
member {
name: "experimental_should_init"
mtype: "<type \'property\'>"
}
member {
name: "parameter_devices"
mtype: "<type \'property\'>"
}
member {
name: "should_checkpoint"
mtype: "<type \'property\'>"
}
member {
name: "should_save_summary"
mtype: "<type \'property\'>"
}
member {
name: "worker_devices"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch_reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast_to"
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call_for_each_replica"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run_steps_on_iterator"
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "non_slot_devices"
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "read_var"
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update"
argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
}
member_method {
name: "update_non_slot"
argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
}
member_method {
name: "value_container"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,133 @@
path: "tensorflow.distribute.Strategy"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
is_instance: "<type \'object\'>"
member {
name: "between_graph"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member {
name: "parameter_devices"
mtype: "<type \'property\'>"
}
member {
name: "require_static_shapes"
mtype: "<type \'property\'>"
}
member {
name: "should_checkpoint"
mtype: "<type \'property\'>"
}
member {
name: "should_init"
mtype: "<type \'property\'>"
}
member {
name: "should_save_summary"
mtype: "<type \'property\'>"
}
member {
name: "worker_devices"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "call_for_each_replica"
argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "distribute_dataset"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_dataset_iterator"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_input_fn_iterator"
argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
}
member_method {
name: "non_slot_devices"
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "read_var"
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "run_steps_on_dataset"
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "unwrap"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update"
argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "update_non_slot"
argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "value_container"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,47 @@
path: "tensorflow.distribute"
tf_module {
member {
name: "InputContext"
mtype: "<type \'type\'>"
}
member {
name: "InputReplicationMode"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "ReduceOp"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "ReplicaContext"
mtype: "<type \'type\'>"
}
member {
name: "Strategy"
mtype: "<type \'type\'>"
}
member {
name: "StrategyExtended"
mtype: "<type \'type\'>"
}
member_method {
name: "get_loss_reduction"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_replica_context"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_strategy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "has_strategy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "in_cross_replica_context"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -324,6 +324,10 @@ tf_module {
name: "debugging" name: "debugging"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"
} }
member {
name: "distribute"
mtype: "<type \'module\'>"
}
member { member {
name: "distributions" name: "distributions"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"

View File

@ -0,0 +1,25 @@
path: "tensorflow.distribute.InputContext"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
is_instance: "<type \'object\'>"
member {
name: "input_pipeline_id"
mtype: "<type \'property\'>"
}
member {
name: "num_input_pipelines"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], "
}
member_method {
name: "get_per_replica_batch_size"
argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,8 @@
path: "tensorflow.distribute.InputReplicationMode"
tf_class {
is_instance: "<enum \'InputReplicationMode\'>"
member {
name: "PER_WORKER"
mtype: "<enum \'InputReplicationMode\'>"
}
}

View File

@ -0,0 +1,12 @@
path: "tensorflow.distribute.ReduceOp"
tf_class {
is_instance: "<enum \'ReduceOp\'>"
member {
name: "MEAN"
mtype: "<enum \'ReduceOp\'>"
}
member {
name: "SUM"
mtype: "<enum \'ReduceOp\'>"
}
}

View File

@ -0,0 +1,33 @@
path: "tensorflow.distribute.ReplicaContext"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
is_instance: "<type \'object\'>"
member {
name: "devices"
mtype: "<type \'property\'>"
}
member {
name: "distribution_strategy"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member {
name: "replica_id_in_sync_group"
mtype: "<type \'property\'>"
}
member {
name: "strategy"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "merge_call"
argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
}

View File

@ -0,0 +1,81 @@
path: "tensorflow.distribute.StrategyExtended"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
is_instance: "<type \'object\'>"
member {
name: "experimental_between_graph"
mtype: "<type \'property\'>"
}
member {
name: "experimental_require_static_shapes"
mtype: "<type \'property\'>"
}
member {
name: "experimental_should_init"
mtype: "<type \'property\'>"
}
member {
name: "parameter_devices"
mtype: "<type \'property\'>"
}
member {
name: "should_checkpoint"
mtype: "<type \'property\'>"
}
member {
name: "should_save_summary"
mtype: "<type \'property\'>"
}
member {
name: "worker_devices"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch_reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast_to"
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call_for_each_replica"
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_run_steps_on_iterator"
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "non_slot_devices"
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "read_var"
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reduce_to"
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update"
argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
}
member_method {
name: "update_non_slot"
argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
}
member_method {
name: "value_container"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,133 @@
path: "tensorflow.distribute.Strategy"
tf_class {
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
is_instance: "<type \'object\'>"
member {
name: "between_graph"
mtype: "<type \'property\'>"
}
member {
name: "extended"
mtype: "<type \'property\'>"
}
member {
name: "num_replicas_in_sync"
mtype: "<type \'property\'>"
}
member {
name: "parameter_devices"
mtype: "<type \'property\'>"
}
member {
name: "require_static_shapes"
mtype: "<type \'property\'>"
}
member {
name: "should_checkpoint"
mtype: "<type \'property\'>"
}
member {
name: "should_init"
mtype: "<type \'property\'>"
}
member {
name: "should_save_summary"
mtype: "<type \'property\'>"
}
member {
name: "worker_devices"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "batch_reduce"
argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "broadcast"
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "call_for_each_replica"
argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "colocate_vars_with"
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "configure"
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "distribute_dataset"
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "experimental_initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "finalize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "group"
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "initialize"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_dataset_iterator"
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "make_input_fn_iterator"
argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
}
member_method {
name: "non_slot_devices"
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "read_var"
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "reduce"
argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "run_steps_on_dataset"
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "scope"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "unwrap"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "update"
argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "update_non_slot"
argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "value_container"
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,47 @@
path: "tensorflow.distribute"
tf_module {
member {
name: "InputContext"
mtype: "<type \'type\'>"
}
member {
name: "InputReplicationMode"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "ReduceOp"
mtype: "<class \'enum.EnumMeta\'>"
}
member {
name: "ReplicaContext"
mtype: "<type \'type\'>"
}
member {
name: "Strategy"
mtype: "<type \'type\'>"
}
member {
name: "StrategyExtended"
mtype: "<type \'type\'>"
}
member_method {
name: "get_loss_reduction"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_replica_context"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_strategy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "has_strategy"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "in_cross_replica_context"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -180,6 +180,10 @@ tf_module {
name: "debugging" name: "debugging"
mtype: "<type \'module\'>" mtype: "<type \'module\'>"
} }
member {
name: "distribute"
mtype: "<type \'module\'>"
}
member { member {
name: "double" name: "double"
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>" mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"