Create tf.distribute namespace with the base classes and standard APIs for distribution strategies. Strategy implementations will be in a future change. Also rename DistributionStrategy to tf.distribute.Strategy, and other changes to match.

RELNOTES: Expose tf.distribute.Strategy as the new name for tf.contrib.distribute.DistributionStrategy.

PiperOrigin-RevId: 222097121
commit a4bcedd676
parent 5c60fb7e9b
Changed paths:
  tensorflow/python/distribute
  tensorflow/python/tools/api/generator
  tensorflow/python/training
  tensorflow/tools/api/golden/v1
    tensorflow.distribute.-input-context.pbtxt
    tensorflow.distribute.-input-replication-mode.pbtxt
    tensorflow.distribute.-reduce-op.pbtxt
    tensorflow.distribute.-replica-context.pbtxt
    tensorflow.distribute.-strategy-extended.pbtxt
    tensorflow.distribute.-strategy.pbtxt
    tensorflow.distribute.pbtxt
    tensorflow.pbtxt
  tensorflow/tools/api/golden/v2
    tensorflow.distribute.-input-context.pbtxt
    tensorflow.distribute.-input-replication-mode.pbtxt
    tensorflow.distribute.-reduce-op.pbtxt
    tensorflow.distribute.-replica-context.pbtxt
    tensorflow.distribute.-strategy-extended.pbtxt
    tensorflow.distribute.-strategy.pbtxt
    tensorflow.distribute.pbtxt
    tensorflow.pbtxt
@@ -216,7 +216,10 @@ py_library(
 py_library(
     name = "reduce_util",
     srcs = ["reduce_util.py"],
-    deps = [],
+    deps = [
+        "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
+    ],
 )

 py_library(
@@ -21,9 +21,10 @@ from __future__ import print_function
 import enum

 from tensorflow.python.ops import variable_scope
+from tensorflow.python.util.tf_export import tf_export


-# TODO(priyag): Add this to tf.distribute namespace when it exists.
+@tf_export("distribute.ReduceOp")
 class ReduceOp(enum.Enum):
   """Indicates how a set of values should be reduced.

@@ -9,6 +9,7 @@ TENSORFLOW_API_INIT_FILES = [
     "data/__init__.py",
     "data/experimental/__init__.py",
     "debugging/__init__.py",
+    "distribute/__init__.py",
     "dtypes/__init__.py",
     "errors/__init__.py",
     "experimental/__init__.py",
@@ -10,6 +10,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
     "data/__init__.py",
     "data/experimental/__init__.py",
     "debugging/__init__.py",
+    "distribute/__init__.py",
     "distributions/__init__.py",
     "dtypes/__init__.py",
     "errors/__init__.py",
@@ -35,10 +35,11 @@ DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)
_TENSORFLOW_DOC_SOURCES = {
    'app': DocSource(docstring_module_name='platform.app'),
    'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
    'compat': DocSource(docstring_module_name='util.compat'),
    'distribute': DocSource(docstring_module_name='training.distribute'),
    'distributions': DocSource(
        docstring_module_name='ops.distributions.distributions'),
    'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
    'errors': DocSource(docstring_module_name='framework.errors'),
    'gfile': DocSource(docstring_module_name='platform.gfile'),
    'graph_util': DocSource(docstring_module_name='framework.graph_util'),
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Class DistributionStrategy, ReplicaContext, and supporting APIs."""
|
||||
"""Library for running a computation across multiple devices."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -38,19 +38,19 @@ from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.training import device_util
|
||||
from tensorflow.python.training import distribution_strategy_context
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
from tensorflow.tools.docs import doc_controls
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Context tracking whether in a distribution.update() or .update_non_slot()
|
||||
# call.
|
||||
# Context tracking whether in a strategy.update() or .update_non_slot() call.
|
||||
|
||||
|
||||
_update_device = threading.local()
|
||||
|
||||
|
||||
def get_update_device():
|
||||
"""Get the current device if in a `DistributionStrategy.update()` call."""
|
||||
"""Get the current device if in a `tf.distribute.Strategy.update()` call."""
|
||||
try:
|
||||
return _update_device.current
|
||||
except AttributeError:
|
||||
@ -77,8 +77,9 @@ class UpdateContext(object):
|
||||
# Public utility functions.
|
||||
|
||||
|
||||
@tf_export("distribute.get_loss_reduction")
|
||||
def get_loss_reduction():
|
||||
"""Reduce op corresponding to the last loss reduction."""
|
||||
"""`tf.distribute.ReduceOp` corresponding to the last loss reduction."""
|
||||
loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
|
||||
if loss_reduction == losses_impl.Reduction.SUM:
|
||||
return reduce_util.ReduceOp.SUM
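For reference (not part of this diff), a hedged sketch of the mapping `get_loss_reduction` implements: the branch visible in this hunk maps a SUM loss reduction to `tf.distribute.ReduceOp.SUM`, and the remainder of the function (truncated here) is assumed to fall back to `ReduceOp.MEAN`.

```
import tensorflow as tf

# Sketch only: mirrors the mapping above, using the public tf.losses.Reduction
# enum. Reductions other than SUM are assumed (per the full function, not shown
# in this hunk) to map to MEAN.
def reduction_to_reduce_op(loss_reduction):
  if loss_reduction == tf.losses.Reduction.SUM:
    return tf.distribute.ReduceOp.SUM
  return tf.distribute.ReduceOp.MEAN

assert reduction_to_reduce_op(tf.losses.Reduction.SUM) is tf.distribute.ReduceOp.SUM
assert reduction_to_reduce_op(tf.losses.Reduction.MEAN) is tf.distribute.ReduceOp.MEAN
```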
|
||||
@ -95,25 +96,25 @@ def _require_cross_replica_context_extended(extended):
|
||||
cross_replica = context.cross_replica_context
|
||||
if cross_replica is not None and cross_replica.extended is extended:
|
||||
return
|
||||
distribution_strategy = extended._container_strategy() # pylint: disable=protected-access
|
||||
strategy = extended._container_strategy() # pylint: disable=protected-access
|
||||
# We have an error to report, figure out the right message.
|
||||
if context.distribution_strategy is not distribution_strategy:
|
||||
_wrong_distribution_strategy_scope(distribution_strategy, context)
|
||||
if context.distribution_strategy is not strategy:
|
||||
_wrong_strategy_scope(strategy, context)
|
||||
assert cross_replica is None
|
||||
raise RuntimeError("Method requires being in cross-replica context, use "
|
||||
"get_replica_context().merge_call()")
|
||||
|
||||
|
||||
def _wrong_distribution_strategy_scope(distribution_strategy, context):
|
||||
def _wrong_strategy_scope(strategy, context):
|
||||
# Figure out the right error message.
|
||||
if not distribution_strategy_context.has_distribution_strategy():
|
||||
raise RuntimeError(
|
||||
'Need to be inside "with distribution_strategy.scope()" for %s' %
|
||||
(distribution_strategy,))
|
||||
'Need to be inside "with strategy.scope()" for %s' %
|
||||
(strategy,))
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Mixing different DistributionStrategy objects: %s is not %s" %
|
||||
(context.distribution_strategy, distribution_strategy))
|
||||
"Mixing different tf.distribute.Strategy objects: %s is not %s" %
|
||||
(context.distribution_strategy, strategy))
|
||||
|
||||
|
||||
def require_replica_context(replica_ctx):
|
||||
@ -124,18 +125,18 @@ def require_replica_context(replica_ctx):
|
||||
if context.replica_context is None:
|
||||
raise RuntimeError("Need to be inside `call_for_each_replica()`")
|
||||
if context.distribution_strategy is replica_ctx.distribution_strategy:
|
||||
# Two different ReplicaContexts with the same DistributionStrategy.
|
||||
# Two different ReplicaContexts with the same tf.distribute.Strategy.
|
||||
raise RuntimeError("Mismatching ReplicaContext.")
|
||||
raise RuntimeError(
|
||||
"Mismatching DistributionStrategy objects: %s is not %s." %
|
||||
"Mismatching tf.distribute.Strategy objects: %s is not %s." %
|
||||
(context.distribution_strategy, replica_ctx.distribution_strategy))
|
||||
|
||||
|
||||
def _require_distribution_strategy_scope_strategy(distribution_strategy):
|
||||
"""Verify in a `distribution_strategy.scope()` in this thread."""
|
||||
def _require_distribution_strategy_scope_strategy(strategy):
|
||||
"""Verify in a `strategy.scope()` in this thread."""
|
||||
context = _get_per_thread_mode()
|
||||
if context.distribution_strategy is distribution_strategy: return
|
||||
_wrong_distribution_strategy_scope(distribution_strategy, context)
|
||||
if context.distribution_strategy is strategy: return
|
||||
_wrong_strategy_scope(strategy, context)
|
||||
|
||||
|
||||
def _require_distribution_strategy_scope_extended(extended):
|
||||
@ -143,8 +144,8 @@ def _require_distribution_strategy_scope_extended(extended):
|
||||
context = _get_per_thread_mode()
|
||||
if context.distribution_strategy.extended is extended: return
|
||||
# Report error.
|
||||
distribution_strategy = extended._container_strategy() # pylint: disable=protected-access
|
||||
_wrong_distribution_strategy_scope(distribution_strategy, context)
|
||||
strategy = extended._container_strategy() # pylint: disable=protected-access
|
||||
_wrong_strategy_scope(strategy, context)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
@ -153,15 +154,18 @@ def _require_distribution_strategy_scope_extended(extended):
|
||||
|
||||
|
||||
class _CurrentDistributionContext(object):
|
||||
"""Context manager for setting the `DistributionStrategy` and var creator."""
|
||||
"""Context manager setting the current `tf.distribute.Strategy`.
|
||||
|
||||
Also: overrides the variable creator and optionally the current device.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
distribution_strategy,
|
||||
strategy,
|
||||
var_creator_scope,
|
||||
var_scope=None,
|
||||
default_device=None):
|
||||
self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access
|
||||
distribution_strategy)
|
||||
strategy)
|
||||
self._var_creator_scope = var_creator_scope
|
||||
self._var_scope = var_scope
|
||||
if default_device:
|
||||
@ -190,8 +194,8 @@ class _CurrentDistributionContext(object):
|
||||
class _SameScopeAgainContext(object):
|
||||
"""Trivial context manager when you are already in `scope()`."""
|
||||
|
||||
def __init__(self, distribution_strategy):
|
||||
self._distribution_strategy = distribution_strategy
|
||||
def __init__(self, strategy):
|
||||
self._distribution_strategy = strategy
|
||||
|
||||
def __enter__(self):
|
||||
return self._distribution_strategy
|
||||
@ -201,6 +205,7 @@ class _SameScopeAgainContext(object):
|
||||
|
||||
|
||||
# TODO(yuefengz): add more replication modes.
|
||||
@tf_export("distribute.InputReplicationMode")
|
||||
class InputReplicationMode(enum.Enum):
|
||||
"""Replication mode for input function."""
|
||||
|
||||
@ -211,6 +216,7 @@ class InputReplicationMode(enum.Enum):
|
||||
PER_WORKER = "PER_WORKER"
|
||||
|
||||
|
||||
@tf_export("distribute.InputContext")
|
||||
class InputContext(object):
|
||||
"""A class wrapping information needed by an input function.
|
||||
|
||||
@ -278,6 +284,7 @@ class InputContext(object):
|
||||
# Base classes for all distribution strategies.
|
||||
|
||||
|
||||
@tf_export("distribute.Strategy")
|
||||
class DistributionStrategy(object):
|
||||
"""A list of devices with a state & compute distribution policy.
|
||||
|
||||
@ -301,14 +308,14 @@ class DistributionStrategy(object):
|
||||
|
||||
@property
|
||||
def extended(self):
|
||||
"""`tf.contrib.distribute.DistributionStrategyExtended` with new methods."""
|
||||
"""`tf.distribute.StrategyExtended` with additional methods."""
|
||||
return self._extended
|
||||
|
||||
def scope(self):
|
||||
"""Returns a context manager selecting this DistributionStrategy as current.
|
||||
"""Returns a context manager selecting this Strategy as current.
|
||||
|
||||
Inside a `with distribution_strategy.scope():` code block, this thread
|
||||
will use a variable creator set by `distribution_strategy`, and will
|
||||
Inside a `with strategy.scope():` code block, this thread
|
||||
will use a variable creator set by `strategy`, and will
|
||||
enter its "cross-replica context".
|
||||
|
||||
Returns:
|
||||
@ -330,20 +337,20 @@ class DistributionStrategy(object):
|
||||
def distribute_dataset(self, dataset_fn):
|
||||
"""Return a `dataset` split across all replicas. DEPRECATED.
|
||||
|
||||
DEPRECATED: Please use `make_dataset_iterator()` or
|
||||
`make_input_fn_iterator()` instead.
|
||||
DEPRECATED: Please use `make_dataset_iterator` or
|
||||
`make_input_fn_iterator` instead.
|
||||
|
||||
Suitable for providing input to for `extended.call_for_each_replica()` by
|
||||
Suitable for providing input to `extended.call_for_each_replica()` by
|
||||
creating an iterator:
|
||||
|
||||
```
|
||||
def dataset_fn():
|
||||
return tf.data.Dataset.from_tensors([[1.]]).repeat()
|
||||
|
||||
with distribution_strategy.scope():
|
||||
distributed_dataset = distribution_strategy.distribute_dataset(dataset_fn)
|
||||
with strategy.scope():
|
||||
distributed_dataset = strategy.distribute_dataset(dataset_fn)
|
||||
iterator = distributed_dataset.make_initializable_iterator()
|
||||
replica_results = distribution_strategy.extended.call_for_each_replica(
|
||||
replica_results = strategy.extended.call_for_each_replica(
|
||||
replica_fn, args=(iterator.get_next(),))
|
||||
```
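A hedged sketch (not part of this diff) of the replacement path the deprecation note points to, using `make_dataset_iterator`. It assumes a concrete strategy implementation; `tf.contrib.distribute.MirroredStrategy` is used here because implementations still live under contrib at this point, and support for the new iterator methods is assumed.

```
import tensorflow as tf

strategy = tf.contrib.distribute.MirroredStrategy(["/cpu:0"])
dataset = tf.data.Dataset.from_tensors([[1.]]).repeat()

def replica_fn(x):
  return x * 2.

with strategy.scope():
  # The new-style iterator replaces distribute_dataset(...).make_*_iterator().
  iterator = strategy.make_dataset_iterator(dataset)
  init_op = iterator.initialize()
  replica_results = strategy.extended.call_for_each_replica(
      replica_fn, args=(iterator.get_next(),))
```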
|
||||
|
||||
@ -374,8 +381,8 @@ class DistributionStrategy(object):
|
||||
replicas.
|
||||
|
||||
Returns:
|
||||
An `InputIterator` which returns inputs for each step of the computation.
|
||||
User should call `initialize` on the returned iterator.
|
||||
An `tf.distribute.InputIterator` which returns inputs for each step of the
|
||||
computation. User should call `initialize` on the returned iterator.
|
||||
"""
|
||||
return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access
|
||||
|
||||
@ -384,26 +391,26 @@ class DistributionStrategy(object):
|
||||
replication_mode=InputReplicationMode.PER_WORKER):
|
||||
"""Returns an iterator split across replicas created from an input function.
|
||||
|
||||
The `input_fn` should take an `InputContext` object where information about
|
||||
input sharding can be accessed:
|
||||
The `input_fn` should take an `tf.distribute.InputContext` object where
|
||||
information about input sharding can be accessed:
|
||||
|
||||
```
|
||||
def input_fn(input_context):
|
||||
d = tf.data.Dataset.from_tensors([[1.]]).repeat()
|
||||
return d.shard(input_context.num_input_pipelines,
|
||||
input_context.input_pipeline_id)
|
||||
with distribution_strategy.scope():
|
||||
iterator = distribution_strategy.make_input_fn_iterator(
|
||||
with strategy.scope():
|
||||
iterator = strategy.make_input_fn_iterator(
|
||||
input_fn)
|
||||
replica_results = distribution_strategy.call_for_each_replica(
|
||||
replica_results = strategy.extended.call_for_each_replica(
|
||||
replica_fn, iterator.get_next())
|
||||
```
|
||||
|
||||
Args:
|
||||
input_fn: A function that returns a `tf.data.Dataset`. This function is
|
||||
expected to take an `InputContext` object.
|
||||
replication_mode: an enum value of `InputReplicationMode`. Only
|
||||
`PER_WORKER` is supported currently.
|
||||
expected to take an `tf.distribute.InputContext` object.
|
||||
replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
|
||||
Only `PER_WORKER` is supported currently.
|
||||
|
||||
Returns:
|
||||
An iterator object that can be initialized and fetched next element.
|
||||
@ -544,7 +551,7 @@ class DistributionStrategy(object):
|
||||
|
||||
Args:
|
||||
value: A value returned by `extended.call_for_each_replica()` or a
|
||||
variable created in `scope()`.
|
||||
variable created in `scope`.
|
||||
|
||||
Returns:
|
||||
A list of values contained in `value`. If `value` represents a single
|
||||
@ -638,17 +645,19 @@ class DistributionStrategy(object):
|
||||
raise RuntimeError("Must only deepcopy DistributionStrategy.")
|
||||
|
||||
|
||||
@tf_export("distribute.StrategyExtended")
|
||||
class DistributionStrategyExtended(object):
|
||||
"""Additional APIs for algorithms that need to be distribution-aware.
|
||||
|
||||
The intent is that you can write an algorithm in a stylized way and
|
||||
it will be usable with a variety of different `DistributionStrategy`
|
||||
it will be usable with a variety of different
|
||||
`tf.distribute.Strategy`
|
||||
implementations. Each descendant will implement a different strategy
|
||||
for distributing the algorithm across multiple devices/machines.
|
||||
Furthermore, these changes can be hidden inside the specific layers
|
||||
and other library classes that need special treatment to run in a
|
||||
distributed setting, so that most users' model definition code can
|
||||
run unchanged. The `DistributionStrategy` API works the same way
|
||||
run unchanged. The `tf.distribute.Strategy` API works the same way
|
||||
with eager and graph execution.
|
||||
|
||||
First let's introduce a few high-level concepts:
|
||||
@ -696,42 +705,33 @@ class DistributionStrategyExtended(object):
|
||||
|
||||
We have then a few approaches we want to support:
|
||||
|
||||
* Code written (as if) with no knowledge of class `DistributionStrategy`.
|
||||
* Code written (as if) with no knowledge of class `tf.distribute.Strategy`.
|
||||
This code should work as before, even if some of the layers, etc.
|
||||
used by that code are written to be distribution-aware. This is done
|
||||
by having a default `DistributionStrategy` that gives ordinary behavior,
|
||||
by having a default `tf.distribute.Strategy` that gives ordinary behavior,
|
||||
and by default being in a single replica context.
|
||||
* Ordinary model code that you want to run using a specific
|
||||
`DistributionStrategy`. This can be as simple as:
|
||||
`tf.distribute.Strategy`. This can be as simple as:
|
||||
|
||||
```
|
||||
with my_distribution.scope():
|
||||
iterator = my_distribution.distribute_dataset(
|
||||
dataset).make_one_shot_iterator()
|
||||
replica_train_ops = my_distribution.extended.call_for_each_replica(
|
||||
with my_strategy.scope():
|
||||
iterator = my_strategy.make_dataset_iterator(dataset)
|
||||
session.run(iterator.initialize())
|
||||
replica_train_ops = my_strategy.extended.call_for_each_replica(
|
||||
replica_fn, args=(iterator.get_next(),))
|
||||
train_op = tf.group(my_distribution.unwrap(replica_train_ops))
|
||||
train_op = my_strategy.group(replica_train_ops)
|
||||
```
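Continuing the snippet above (an illustrative note, not part of this diff): a hedged reading of the API is that `group` combines the unwrapped per-replica values into a single op, so the old `tf.group(unwrap(...))` spelling and the new `group(...)` call should be interchangeable.

```
# Continues the example above; assumes the same my_strategy / replica_train_ops.
train_op = my_strategy.group(replica_train_ops)
# Believed equivalent (hedged): group the unwrapped per-replica values explicitly.
also_train_op = tf.group(my_strategy.unwrap(replica_train_ops))
```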
|
||||
|
||||
This takes an ordinary `dataset` and `replica_fn` and runs it
|
||||
distributed using a particular `DistributionStrategy` in
|
||||
`my_distribution`. Any variables created in `replica_fn` are created
|
||||
using `my_distribution`'s policy, and library functions called by
|
||||
distributed using a particular `tf.distribute.Strategy` in
|
||||
`my_strategy`. Any variables created in `replica_fn` are created
|
||||
using `my_strategy`'s policy, and library functions called by
|
||||
`replica_fn` can use the `get_replica_context()` API to get enhanced
|
||||
behavior in this case.
|
||||
|
||||
You can also create an initializable iterator instead of a one-shot
|
||||
iterator. In that case, you will need to ensure that you initialize the
|
||||
iterator before calling get_next.
|
||||
```
|
||||
iterator = my_distribution.distribute_dataset(
|
||||
dataset).make_initializable_iterator())
|
||||
session.run(iterator.initializer)
|
||||
```
|
||||
|
||||
* If you want to write a distributed algorithm, you may use any of
|
||||
the `DistributionStrategy` APIs inside a
|
||||
`with my_distribution.scope():` block of code.
|
||||
the `tf.distribute.Strategy` APIs inside a
|
||||
`with my_strategy.scope():` block of code.
|
||||
|
||||
Lower-level concepts:
|
||||
|
||||
@ -758,7 +758,7 @@ class DistributionStrategyExtended(object):
|
||||
* Replica context vs. Cross-replica context: _replica context_ is when we
|
||||
are in some function that is being called once for each replica.
|
||||
Otherwise we are in cross-replica context, which is useful for
|
||||
calling `DistributionStrategy` methods which operate across the
|
||||
calling `tf.distribute.Strategy` methods which operate across the
|
||||
replicas (like `reduce_to()`). By default you start in a replica context
|
||||
(the default "single replica context") and then some methods can
|
||||
switch you back and forth, as described below.
|
||||
@ -778,7 +778,7 @@ class DistributionStrategyExtended(object):
|
||||
pick a consistent set of devices to pass to both
|
||||
`colocate_vars_with()` and `update_non_slot()`.
|
||||
|
||||
When using a `DistributionStrategy`, we have a new type dimension
|
||||
When using a `tf.distribute.Strategy`, we have a new type dimension
|
||||
called _locality_ that says what values are compatible with which
|
||||
APIs:
|
||||
|
||||
@ -856,19 +856,20 @@ class DistributionStrategyExtended(object):
|
||||
input and destination.
|
||||
|
||||
Layers should expect to be called in a replica context, and can use
|
||||
the `get_replica_context()` function to get a `ReplicaContext` object. The
|
||||
the `tf.distribute.get_replica_context` function to get a
|
||||
`tf.distribute.ReplicaContext` object. The
|
||||
`ReplicaContext` object has a `merge_call()` method for entering
|
||||
cross-replica context where you can use `reduce_to()` (or
|
||||
`batch_reduce_to()`) and then optionally `update()` to update state.
|
||||
|
||||
You may use this API whether or not a `DistributionStrategy` is
|
||||
You may use this API whether or not a `tf.distribute.Strategy` is
|
||||
being used, since there is a default implementation of
|
||||
`ReplicaContext` and `DistributionStrategy`.
|
||||
`ReplicaContext` and `tf.distribute.Strategy`.
|
||||
|
||||
NOTE for new `DistributionStrategy` implementations: Please put all logic
|
||||
in a subclass of `DistributionStrategyExtended`. The only code needed for
|
||||
the `DistributionStrategy` subclass is for instantiating your subclass of
|
||||
`DistributionStrategyExtended` in the `__init__` method.
|
||||
NOTE for new `tf.distribute.Strategy` implementations: Please put all logic
|
||||
in a subclass of `tf.distribute.StrategyExtended`. The only code needed for
|
||||
the `tf.distribute.Strategy` subclass is for instantiating your subclass of
|
||||
`tf.distribute.StrategyExtended` in the `__init__` method.
|
||||
"""
|
||||
|
||||
def __init__(self, container_strategy):
|
||||
@ -908,7 +909,7 @@ class DistributionStrategyExtended(object):
|
||||
if kwargs.pop("partitioner", None) is not None:
|
||||
tf_logging.log_first_n(
|
||||
tf_logging.WARN, "Partitioned variables are disabled when using "
|
||||
"current DistributionStrategy.", 1)
|
||||
"current tf.distribute.Strategy.", 1)
|
||||
return getter(*args, **kwargs)
|
||||
|
||||
return _CurrentDistributionContext(
|
||||
@ -932,7 +933,7 @@ class DistributionStrategyExtended(object):
|
||||
(read-only) value of any other variable.
|
||||
|
||||
Args:
|
||||
v: A variable allocated within the scope of this `DistributionStrategy`.
|
||||
v: A variable allocated within the scope of this `tf.distribute.Strategy`.
|
||||
|
||||
Returns:
|
||||
A tensor representing the value of `v`, aggregated across replicas if
|
||||
@ -953,9 +954,9 @@ class DistributionStrategyExtended(object):
|
||||
Example usage:
|
||||
|
||||
```
|
||||
with distribution_strategy.scope():
|
||||
with strategy.scope():
|
||||
var1 = tf.get_variable(...)
|
||||
with distribution_strategy.colocate_vars_with(v1):
|
||||
with strategy.extended.colocate_vars_with(v1):
|
||||
# var2 and var3 will be created on the same device(s) as var1
|
||||
var2 = tf.get_variable(...)
|
||||
var3 = tf.get_variable(...)
|
||||
@ -964,7 +965,7 @@ class DistributionStrategyExtended(object):
|
||||
# operates on v1 from var1, v2 from var2, and v3 from var3
|
||||
|
||||
# `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
|
||||
distribution_strategy.update(v1, fn, args=(v2, v3))
|
||||
strategy.extended.update(v1, fn, args=(v2, v3))
|
||||
```
|
||||
|
||||
Args:
|
||||
@ -990,7 +991,7 @@ class DistributionStrategyExtended(object):
|
||||
if not isinstance(result, dataset_ops.Dataset):
|
||||
raise ValueError(
|
||||
"dataset_fn() must return a tf.data.Dataset when using a "
|
||||
"DistributionStrategy.")
|
||||
"tf.distribute.Strategy.")
|
||||
return result
|
||||
|
||||
# TODO(josh11b): `PerReplicaDataset` currently only implements a few methods of
|
||||
@ -1320,7 +1321,7 @@ class DistributionStrategyExtended(object):
|
||||
|
||||
Args:
|
||||
var_list: The list of variables being optimized, needed with the
|
||||
default `DistributionStrategy`.
|
||||
default `tf.distribute.Strategy`.
|
||||
"""
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
@ -1374,13 +1375,18 @@ class DistributionStrategyExtended(object):
|
||||
# around their model creation and graph definition. There is no
|
||||
# anticipated need to define descendants of _CurrentDistributionContext.
|
||||
# It sets the current DistributionStrategy for purposes of
|
||||
# `get_distribution_strategy()` and `has_distribution_strategy()`
|
||||
# `get_strategy()` and `has_strategy()`
|
||||
# and switches the thread mode to a "cross-replica context".
|
||||
@tf_export("distribute.ReplicaContext")
|
||||
class ReplicaContext(object):
|
||||
"""DistributionStrategy API inside a `call_for_each_replica()` call."""
|
||||
"""`tf.distribute.Strategy` API when in a replica context.
|
||||
|
||||
def __init__(self, distribution_strategy, replica_id_in_sync_group):
|
||||
self._distribution_strategy = distribution_strategy
|
||||
To be used inside your replicated step function, such as in a
|
||||
`tf.distribute.StrategyExtended.call_for_each_replica` call.
|
||||
"""
|
||||
|
||||
def __init__(self, strategy, replica_id_in_sync_group):
|
||||
self._distribution_strategy = strategy
|
||||
self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access
|
||||
self)
|
||||
self._replica_id_in_sync_group = replica_id_in_sync_group
|
||||
@ -1396,23 +1402,24 @@ class ReplicaContext(object):
|
||||
|
||||
This allows communication and coordination when there are multiple calls
|
||||
to a model function triggered by a call to
|
||||
`distribution.call_for_each_replica(model_fn, ...)`.
|
||||
`strategy.extended.call_for_each_replica(model_fn, ...)`.
|
||||
|
||||
See `MirroredDistribution.call_for_each_replica()` for an explanation.
|
||||
See `tf.distribute.StrategyExtended.call_for_each_replica` for an
|
||||
explanation.
|
||||
|
||||
Otherwise, this is equivalent to:
|
||||
If not inside a distributed scope, this is equivalent to:
|
||||
|
||||
```
|
||||
distribution = get_distribution_strategy()
|
||||
with cross-replica-context(distribution):
|
||||
return merge_fn(distribution, *args, **kwargs)
|
||||
strategy = tf.distribute.get_strategy()
|
||||
with cross-replica-context(strategy):
|
||||
return merge_fn(strategy, *args, **kwargs)
|
||||
```
|
||||
|
||||
Args:
|
||||
merge_fn: function that joins arguments from threads that are given as
|
||||
PerReplica. It accepts `DistributionStrategy` object as the first
|
||||
argument.
|
||||
args: List or tuple with positional per-thread arguments for `merge_fn`
|
||||
PerReplica. It accepts `tf.distribute.Strategy` object as
|
||||
the first argument.
|
||||
args: List or tuple with positional per-thread arguments for `merge_fn`.
|
||||
kwargs: Dict with keyword per-thread arguments for `merge_fn`.
|
||||
|
||||
Returns:
|
||||
@ -1446,8 +1453,14 @@ class ReplicaContext(object):
|
||||
return self._replica_id_in_sync_group
|
||||
|
||||
@property
|
||||
@doc_controls.do_not_generate_docs # DEPRECATED, use `strategy`
|
||||
def distribution_strategy(self):
|
||||
"""The current `DistributionStrategy` object."""
|
||||
"""DEPRECATED: use `self.stratgey` instead."""
|
||||
return self._distribution_strategy
|
||||
|
||||
@property
|
||||
def strategy(self):
|
||||
"""The current `tf.distribute.Strategy` object."""
|
||||
return self._distribution_strategy
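A small sketch (not part of this diff) of these properties side by side; even without an explicit strategy you are in the default single-replica context, so this runs as-is.

```
import tensorflow as tf

ctx = tf.distribute.get_replica_context()   # default replica context
print(ctx.num_replicas_in_sync)             # 1 under the default strategy
print(ctx.replica_id_in_sync_group)         # id of this replica (0 here)
print(ctx.strategy)                          # new name; .distribution_strategy is deprecated
```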
|
||||
|
||||
@property
|
||||
@ -1470,7 +1483,7 @@ class ReplicaContext(object):
|
||||
|
||||
|
||||
class _DefaultDistributionStrategy(DistributionStrategy):
|
||||
"""Default `DistributionStrategy` if none is explicitly selected."""
|
||||
"""Default `tf.distribute.Strategy` if none is explicitly selected."""
|
||||
|
||||
def __init__(self):
|
||||
super(_DefaultDistributionStrategy, self).__init__(
|
||||
@ -1483,7 +1496,7 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
|
||||
def _scope(self, strategy):
|
||||
"""Context manager setting a variable creator and `self` as current."""
|
||||
if distribution_strategy_context.has_distribution_strategy():
|
||||
raise RuntimeError("Must not nest DistributionStrategy scopes.")
|
||||
raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
|
||||
|
||||
def creator(next_creator, *args, **kwargs):
|
||||
_require_distribution_strategy_scope_strategy(strategy)
|
||||
@ -1555,13 +1568,13 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
|
||||
|
||||
@property
|
||||
def worker_devices(self):
|
||||
raise RuntimeError(
|
||||
"worker_devices() method unsupported by _DefaultDistributionStrategy.")
|
||||
raise RuntimeError("worker_devices() method unsupported by default "
|
||||
"tf.distribute.Strategy.")
|
||||
|
||||
@property
|
||||
def parameter_devices(self):
|
||||
raise RuntimeError("parameter_devices() method unsupported by "
|
||||
"_DefaultDistributionStrategy.")
|
||||
raise RuntimeError("parameter_devices() method unsupported by default "
|
||||
"tf.distribute.Strategy.")
|
||||
|
||||
def non_slot_devices(self, var_list):
|
||||
return min(var_list, key=lambda x: x.name)
|
||||
@ -1600,8 +1613,8 @@ _original_from_proto = resource_variable_ops._from_proto_fn
|
||||
def _from_proto_fn(v, import_scope=None):
|
||||
if distribution_strategy_context.has_distribution_strategy():
|
||||
raise NotImplementedError(
|
||||
"Deserialization of variables is not yet supported when using"
|
||||
"distributed strategies.")
|
||||
"Deserialization of variables is not yet supported when using a "
|
||||
"tf.distribute.Strategy.")
|
||||
else:
|
||||
return _original_from_proto(v, import_scope=import_scope)
|
||||
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.util.lazy_loader import LazyLoader
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
# There is a circular dependency between this and `distribute` module. So we
|
||||
@ -85,6 +86,7 @@ def _get_per_thread_mode():
|
||||
# Public API for accessing the current thread mode
|
||||
|
||||
|
||||
@tf_export("distribute.get_replica_context")
|
||||
def get_replica_context():
|
||||
"""Returns the current `tf.distribute.ReplicaContext` or `None`.
|
||||
|
||||
@ -95,7 +97,7 @@ def get_replica_context():
|
||||
1. starts in the default (single-replica) replica context (this function
|
||||
will return the default `ReplicaContext` object);
|
||||
2. switches to cross-replica context (in which case this will return
|
||||
`None`) when entering a `with DistributionStrategy.scope():` block;
|
||||
`None`) when entering a `with tf.distribute.Strategy.scope():` block;
|
||||
3. switches to a (non-default) replica context inside
|
||||
`extended.call_for_each_replica(fn, ...)`;
|
||||
4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
|
||||
@ -103,11 +105,11 @@ def get_replica_context():
|
||||
this function will return `None`).
|
||||
|
||||
Note that you can also go directly from step 1 to 4 to switch to a
|
||||
cross-replica context for the default `DistributionStrategy`. You may
|
||||
cross-replica context for the default `tf.distribute.Strategy`. You may
|
||||
also switch from the cross-replica context of 4 to a replica context by
|
||||
calling `extended.call_for_each_replica()`, jumping back to step 3.
|
||||
|
||||
Most `DistributionStrategy` methods may only be executed in
|
||||
Most `tf.distribute.Strategy` methods may only be executed in
|
||||
a cross-replica context, in a replica context you should use the
|
||||
`ReplicaContext` API instead.
|
||||
|
||||
@ -124,7 +126,7 @@ def get_replica_context():
|
||||
|
||||
|
||||
def get_cross_replica_context():
|
||||
"""Returns the current DistributionStrategy if in a cross-replica context.
|
||||
"""Returns the current tf.distribute.Strategy if in a cross-replica context.
|
||||
|
||||
DEPRECATED: Please use `in_cross_replica_context()` and
|
||||
`get_distribution_strategy()` instead.
|
||||
@ -133,22 +135,22 @@ def get_cross_replica_context():
|
||||
|
||||
1. starts in the default (single-replica) replica context;
|
||||
2. switches to cross-replica context when entering a
|
||||
`with DistributionStrategy.scope():` block;
|
||||
`with tf.distribute.Strategy.scope():` block;
|
||||
3. switches to a (non-default) replica context inside
|
||||
`call_for_each_replica(fn, ...)`;
|
||||
4. if `fn` calls `get_replica_context()->merge_call(merge_fn, ...)`, then
|
||||
inside `merge_fn` you are back in the cross-replica context.
|
||||
|
||||
Note that you can also go directly from step 1 to 4 to switch to a
|
||||
cross-replica context for the default `DistributionStrategy`. You may
|
||||
cross-replica context for the default `tf.distribute.Strategy`. You may
|
||||
also switch from the cross-replica context of 4 to a replica context by
|
||||
calling `call_for_each_replica()`, jumping back to step 3.
|
||||
|
||||
Most `DistributionStrategy` methods may only be executed in
|
||||
Most `tf.distribute.Strategy` methods may only be executed in
|
||||
a cross-replica context.
|
||||
|
||||
Returns:
|
||||
Returns the current `DistributionStrategy` object in a cross-replica
|
||||
Returns the current `tf.distribute.Strategy` object in a cross-replica
|
||||
context, or `None`.
|
||||
|
||||
Exactly one of `get_replica_context()` and `get_cross_replica_context()`
|
||||
@ -157,6 +159,7 @@ def get_cross_replica_context():
|
||||
return _get_per_thread_mode().cross_replica_context
|
||||
|
||||
|
||||
@tf_export("distribute.in_cross_replica_context")
|
||||
def in_cross_replica_context():
|
||||
"""Returns True if in a cross-replica context.
|
||||
|
||||
@ -170,31 +173,33 @@ def in_cross_replica_context():
|
||||
return _get_per_thread_mode().cross_replica_context is not None
|
||||
|
||||
|
||||
@tf_export("distribute.get_strategy")
|
||||
def get_distribution_strategy():
|
||||
"""Returns the current `DistributionStrategy` object.
|
||||
"""Returns the current `tf.distribute.Strategy` object.
|
||||
|
||||
Typically only used in a cross-replica context:
|
||||
|
||||
```
|
||||
if tf.distribute.in_cross_replica_context():
|
||||
strategy = tf.distribute.get_distribution_strategy()
|
||||
strategy = tf.distribute.get_strategy()
|
||||
...
|
||||
```
|
||||
|
||||
Returns:
|
||||
A `DistributionStrategy` object. Inside a
|
||||
A `tf.distribute.Strategy` object. Inside a
|
||||
`with distribution_strategy.scope()` block, it returns
|
||||
`distribution_strategy`, otherwise it returns the default
|
||||
(single-replica) `DistributionStrategy` object.
|
||||
(single-replica) `tf.distribute.Strategy` object.
|
||||
"""
|
||||
return _get_per_thread_mode().distribution_strategy
|
||||
|
||||
|
||||
@tf_export("distribute.has_strategy")
|
||||
def has_distribution_strategy():
|
||||
"""Return if there is a current non-default `DistributionStrategy`.
|
||||
"""Return if there is a current non-default `tf.distribute.Strategy`.
|
||||
|
||||
Returns:
|
||||
True if inside a `with distribution_strategy.scope():`.
|
||||
True if inside a `with strategy.scope():`.
|
||||
"""
|
||||
return get_distribution_strategy() is not _get_default_distribution_strategy()
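A hedged sketch (not part of this diff) of the renamed accessors; the concrete strategy is assumed to come from contrib, where implementations still live at this point.

```
import tensorflow as tf

# Outside any scope there is still a strategy: the default single-replica one.
assert not tf.distribute.has_strategy()
default_strategy = tf.distribute.get_strategy()

strategy = tf.contrib.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  # Inside the scope the same accessors return the strategy that opened it.
  assert tf.distribute.has_strategy()
  assert tf.distribute.get_strategy() is strategy
```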
|
||||
|
||||
|
@ -0,0 +1,25 @@
|
||||
path: "tensorflow.distribute.InputContext"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "input_pipeline_id"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "num_input_pipelines"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "num_replicas_in_sync"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_per_replica_batch_size"
|
||||
argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,8 @@
|
||||
path: "tensorflow.distribute.InputReplicationMode"
|
||||
tf_class {
|
||||
is_instance: "<enum \'InputReplicationMode\'>"
|
||||
member {
|
||||
name: "PER_WORKER"
|
||||
mtype: "<enum \'InputReplicationMode\'>"
|
||||
}
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
path: "tensorflow.distribute.ReduceOp"
|
||||
tf_class {
|
||||
is_instance: "<enum \'ReduceOp\'>"
|
||||
member {
|
||||
name: "MEAN"
|
||||
mtype: "<enum \'ReduceOp\'>"
|
||||
}
|
||||
member {
|
||||
name: "SUM"
|
||||
mtype: "<enum \'ReduceOp\'>"
|
||||
}
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
path: "tensorflow.distribute.ReplicaContext"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "devices"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "distribution_strategy"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "num_replicas_in_sync"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "replica_id_in_sync_group"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "strategy"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "merge_call"
|
||||
argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
}
|
@ -0,0 +1,81 @@
|
||||
path: "tensorflow.distribute.StrategyExtended"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "experimental_between_graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_require_static_shapes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_should_init"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "parameter_devices"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "should_checkpoint"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "should_save_summary"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "worker_devices"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "batch_reduce_to"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "broadcast_to"
|
||||
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call_for_each_replica"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "colocate_vars_with"
|
||||
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_steps_on_iterator"
|
||||
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "non_slot_devices"
|
||||
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "read_var"
|
||||
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reduce_to"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "update"
|
||||
argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "update_non_slot"
|
||||
argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "value_container"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,133 @@
|
||||
path: "tensorflow.distribute.Strategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "between_graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "extended"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "num_replicas_in_sync"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "parameter_devices"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "require_static_shapes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "should_checkpoint"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "should_init"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "should_save_summary"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "worker_devices"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "batch_reduce"
|
||||
argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "broadcast"
|
||||
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "call_for_each_replica"
|
||||
argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "colocate_vars_with"
|
||||
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "configure"
|
||||
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "distribute_dataset"
|
||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_finalize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_initialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "finalize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "initialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_dataset_iterator"
|
||||
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_input_fn_iterator"
|
||||
argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "non_slot_devices"
|
||||
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "read_var"
|
||||
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reduce"
|
||||
argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "run_steps_on_dataset"
|
||||
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "scope"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "unwrap"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "update"
|
||||
argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "update_non_slot"
|
||||
argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "value_container"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
tensorflow/tools/api/golden/v1/tensorflow.distribute.pbtxt (new file, 47 lines)
@ -0,0 +1,47 @@
|
||||
path: "tensorflow.distribute"
|
||||
tf_module {
|
||||
member {
|
||||
name: "InputContext"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "InputReplicationMode"
|
||||
mtype: "<class \'enum.EnumMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "ReduceOp"
|
||||
mtype: "<class \'enum.EnumMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "ReplicaContext"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Strategy"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "StrategyExtended"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "get_loss_reduction"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_replica_context"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_strategy"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "has_strategy"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "in_cross_replica_context"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -324,6 +324,10 @@ tf_module {
|
||||
name: "debugging"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "distribute"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "distributions"
|
||||
mtype: "<type \'module\'>"
|
||||
|
@ -0,0 +1,25 @@
|
||||
path: "tensorflow.distribute.InputContext"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "input_pipeline_id"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "num_input_pipelines"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "num_replicas_in_sync"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "get_per_replica_batch_size"
|
||||
argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,8 @@
|
||||
path: "tensorflow.distribute.InputReplicationMode"
|
||||
tf_class {
|
||||
is_instance: "<enum \'InputReplicationMode\'>"
|
||||
member {
|
||||
name: "PER_WORKER"
|
||||
mtype: "<enum \'InputReplicationMode\'>"
|
||||
}
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
path: "tensorflow.distribute.ReduceOp"
|
||||
tf_class {
|
||||
is_instance: "<enum \'ReduceOp\'>"
|
||||
member {
|
||||
name: "MEAN"
|
||||
mtype: "<enum \'ReduceOp\'>"
|
||||
}
|
||||
member {
|
||||
name: "SUM"
|
||||
mtype: "<enum \'ReduceOp\'>"
|
||||
}
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
path: "tensorflow.distribute.ReplicaContext"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "devices"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "distribution_strategy"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "num_replicas_in_sync"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "replica_id_in_sync_group"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "strategy"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "merge_call"
|
||||
argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
}
|
@ -0,0 +1,81 @@
|
||||
path: "tensorflow.distribute.StrategyExtended"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "experimental_between_graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_require_static_shapes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "experimental_should_init"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "parameter_devices"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "should_checkpoint"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "should_save_summary"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "worker_devices"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "batch_reduce_to"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "broadcast_to"
|
||||
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "call_for_each_replica"
|
||||
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "colocate_vars_with"
|
||||
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_run_steps_on_iterator"
|
||||
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "non_slot_devices"
|
||||
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "read_var"
|
||||
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reduce_to"
|
||||
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "update"
|
||||
argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "update_non_slot"
|
||||
argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "value_container"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -0,0 +1,133 @@
|
||||
path: "tensorflow.distribute.Strategy"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "between_graph"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "extended"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "num_replicas_in_sync"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "parameter_devices"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "require_static_shapes"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "should_checkpoint"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "should_init"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "should_save_summary"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "worker_devices"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "batch_reduce"
|
||||
argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "broadcast"
|
||||
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "call_for_each_replica"
|
||||
argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "colocate_vars_with"
|
||||
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "configure"
|
||||
argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "distribute_dataset"
|
||||
argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_finalize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "experimental_initialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "finalize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "group"
|
||||
argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "initialize"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_dataset_iterator"
|
||||
argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "make_input_fn_iterator"
|
||||
argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "non_slot_devices"
|
||||
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "read_var"
|
||||
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "reduce"
|
||||
argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "run_steps_on_dataset"
|
||||
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "scope"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "unwrap"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "update"
|
||||
argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "update_non_slot"
|
||||
argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "value_container"
|
||||
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt (new file, 47 lines)
@ -0,0 +1,47 @@
|
||||
path: "tensorflow.distribute"
|
||||
tf_module {
|
||||
member {
|
||||
name: "InputContext"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "InputReplicationMode"
|
||||
mtype: "<class \'enum.EnumMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "ReduceOp"
|
||||
mtype: "<class \'enum.EnumMeta\'>"
|
||||
}
|
||||
member {
|
||||
name: "ReplicaContext"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "Strategy"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "StrategyExtended"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "get_loss_reduction"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_replica_context"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_strategy"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "has_strategy"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "in_cross_replica_context"
|
||||
argspec: "args=[], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -180,6 +180,10 @@ tf_module {
|
||||
name: "debugging"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "distribute"
|
||||
mtype: "<type \'module\'>"
|
||||
}
|
||||
member {
|
||||
name: "double"
|
||||
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
|
||||
|