Create tf.distribute namespace with the base classes and standard APIs for
distribution strategies. Strategy implementations will be in a future change.

Also rename DistributionStrategy to tf.distribute.Strategy, and make other
changes to match.

RELNOTES: Expose tf.distribute.Strategy as the new name for
tf.contrib.distribute.DistributionStrategy.
PiperOrigin-RevId: 222097121
This commit is contained in:
  parent 5c60fb7e9b
  commit a4bcedd676

Changed paths:
  tensorflow/python/distribute/
  tensorflow/python/tools/api/generator/
  tensorflow/python/training/
  tensorflow/tools/api/golden/v1/
    tensorflow.distribute.-input-context.pbtxt
    tensorflow.distribute.-input-replication-mode.pbtxt
    tensorflow.distribute.-reduce-op.pbtxt
    tensorflow.distribute.-replica-context.pbtxt
    tensorflow.distribute.-strategy-extended.pbtxt
    tensorflow.distribute.-strategy.pbtxt
    tensorflow.distribute.pbtxt
    tensorflow.pbtxt
  tensorflow/tools/api/golden/v2/
    tensorflow.distribute.-input-context.pbtxt
    tensorflow.distribute.-input-replication-mode.pbtxt
    tensorflow.distribute.-reduce-op.pbtxt
    tensorflow.distribute.-replica-context.pbtxt
    tensorflow.distribute.-strategy-extended.pbtxt
    tensorflow.distribute.-strategy.pbtxt
    tensorflow.distribute.pbtxt
    tensorflow.pbtxt
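For orientation, a rough sketch of what the rename means for user code once the exported symbols are available (illustrative only; concrete strategy implementations are not part of this change):

```python
import tensorflow as tf

# Previously the base class was only reachable as
# tf.contrib.distribute.DistributionStrategy; it is now exported as
# tf.distribute.Strategy, with the extended API as tf.distribute.StrategyExtended.
assert hasattr(tf.distribute, "Strategy")
assert hasattr(tf.distribute, "StrategyExtended")
assert hasattr(tf.distribute, "ReplicaContext")
```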
@@ -216,7 +216,10 @@ py_library(
 py_library(
     name = "reduce_util",
     srcs = ["reduce_util.py"],
-    deps = [],
+    deps = [
+        "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
+    ],
 )

 py_library(
@@ -21,9 +21,10 @@ from __future__ import print_function
 import enum

 from tensorflow.python.ops import variable_scope
+from tensorflow.python.util.tf_export import tf_export


-# TODO(priyag): Add this to tf.distribute namespace when it exists.
+@tf_export("distribute.ReduceOp")
 class ReduceOp(enum.Enum):
   """Indicates how a set of values should be reduced.

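A hedged sketch of what the new `@tf_export("distribute.ReduceOp")` decoration buys: the enum defined in reduce_util.py becomes addressable under the public namespace and can be handed to cross-replica reductions.

```python
import tensorflow as tf

# The same enum, now visible as tf.distribute.ReduceOp.
print(tf.distribute.ReduceOp.SUM)   # ReduceOp.SUM
print(tf.distribute.ReduceOp.MEAN)  # ReduceOp.MEAN

# Typical use (assuming `strategy` is some tf.distribute.Strategy):
#   strategy.extended.reduce_to(tf.distribute.ReduceOp.MEAN, value, destinations)
```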
@@ -9,6 +9,7 @@ TENSORFLOW_API_INIT_FILES = [
     "data/__init__.py",
     "data/experimental/__init__.py",
     "debugging/__init__.py",
+    "distribute/__init__.py",
     "dtypes/__init__.py",
     "errors/__init__.py",
     "experimental/__init__.py",
@@ -10,6 +10,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
     "data/__init__.py",
     "data/experimental/__init__.py",
     "debugging/__init__.py",
+    "distribute/__init__.py",
    "distributions/__init__.py",
     "dtypes/__init__.py",
     "errors/__init__.py",
@@ -35,10 +35,11 @@ DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)

 _TENSORFLOW_DOC_SOURCES = {
     'app': DocSource(docstring_module_name='platform.app'),
+    'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
     'compat': DocSource(docstring_module_name='util.compat'),
+    'distribute': DocSource(docstring_module_name='training.distribute'),
     'distributions': DocSource(
         docstring_module_name='ops.distributions.distributions'),
-    'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
     'errors': DocSource(docstring_module_name='framework.errors'),
     'gfile': DocSource(docstring_module_name='platform.gfile'),
     'graph_util': DocSource(docstring_module_name='framework.graph_util'),
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Class DistributionStrategy, ReplicaContext, and supporting APIs."""
+"""Library for running a computation across multiple devices."""

 from __future__ import absolute_import
 from __future__ import division
@@ -38,19 +38,19 @@ from tensorflow.python.platform import tf_logging
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribution_strategy_context
 from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
 from tensorflow.tools.docs import doc_controls


 # ------------------------------------------------------------------------------
-# Context tracking whether in a distribution.update() or .update_non_slot()
-# call.
+# Context tracking whether in a strategy.update() or .update_non_slot() call.


 _update_device = threading.local()


 def get_update_device():
-  """Get the current device if in a `DistributionStrategy.update()` call."""
+  """Get the current device if in a `tf.distribute.Strategy.update()` call."""
   try:
     return _update_device.current
   except AttributeError:
@@ -77,8 +77,9 @@ class UpdateContext(object):
 # Public utility functions.


+@tf_export("distribute.get_loss_reduction")
 def get_loss_reduction():
-  """Reduce op corresponding to the last loss reduction."""
+  """`tf.distribute.ReduceOp` corresponding to the last loss reduction."""
   loss_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
   if loss_reduction == losses_impl.Reduction.SUM:
     return reduce_util.ReduceOp.SUM
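A simplified restatement of the mapping `get_loss_reduction()` performs, assuming `reduce_util` lives under `tensorflow/python/distribute` as the BUILD hunk above suggests (the hunk only shows the SUM branch; the MEAN fallback below is the conventional behavior and is an assumption here):

```python
from tensorflow.python.distribute import reduce_util
from tensorflow.python.ops.losses import losses_impl

def loss_reduction_to_reduce_op(loss_reduction):
  """Maps a tf.losses Reduction to a tf.distribute.ReduceOp (sketch)."""
  if loss_reduction == losses_impl.Reduction.SUM:
    # A summed loss should also be summed across replicas.
    return reduce_util.ReduceOp.SUM
  # Mean-style reductions (e.g. SUM_OVER_BATCH_SIZE) average across replicas.
  return reduce_util.ReduceOp.MEAN
```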
@@ -95,25 +96,25 @@ def _require_cross_replica_context_extended(extended):
   cross_replica = context.cross_replica_context
   if cross_replica is not None and cross_replica.extended is extended:
     return
-  distribution_strategy = extended._container_strategy()  # pylint: disable=protected-access
+  strategy = extended._container_strategy()  # pylint: disable=protected-access
   # We have an error to report, figure out the right message.
-  if context.distribution_strategy is not distribution_strategy:
-    _wrong_distribution_strategy_scope(distribution_strategy, context)
+  if context.distribution_strategy is not strategy:
+    _wrong_strategy_scope(strategy, context)
   assert cross_replica is None
   raise RuntimeError("Method requires being in cross-replica context, use "
                      "get_replica_context().merge_call()")


-def _wrong_distribution_strategy_scope(distribution_strategy, context):
+def _wrong_strategy_scope(strategy, context):
   # Figure out the right error message.
   if not distribution_strategy_context.has_distribution_strategy():
     raise RuntimeError(
-        'Need to be inside "with distribution_strategy.scope()" for %s' %
-        (distribution_strategy,))
+        'Need to be inside "with strategy.scope()" for %s' %
+        (strategy,))
   else:
     raise RuntimeError(
-        "Mixing different DistributionStrategy objects: %s is not %s" %
-        (context.distribution_strategy, distribution_strategy))
+        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
+        (context.distribution_strategy, strategy))


 def require_replica_context(replica_ctx):
@@ -124,18 +125,18 @@ def require_replica_context(replica_ctx):
   if context.replica_context is None:
     raise RuntimeError("Need to be inside `call_for_each_replica()`")
   if context.distribution_strategy is replica_ctx.distribution_strategy:
-    # Two different ReplicaContexts with the same DistributionStrategy.
+    # Two different ReplicaContexts with the same tf.distribute.Strategy.
     raise RuntimeError("Mismatching ReplicaContext.")
   raise RuntimeError(
-      "Mismatching DistributionStrategy objects: %s is not %s." %
+      "Mismatching tf.distribute.Strategy objects: %s is not %s." %
       (context.distribution_strategy, replica_ctx.distribution_strategy))


-def _require_distribution_strategy_scope_strategy(distribution_strategy):
-  """Verify in a `distribution_strategy.scope()` in this thread."""
+def _require_distribution_strategy_scope_strategy(strategy):
+  """Verify in a `strategy.scope()` in this thread."""
   context = _get_per_thread_mode()
-  if context.distribution_strategy is distribution_strategy: return
-  _wrong_distribution_strategy_scope(distribution_strategy, context)
+  if context.distribution_strategy is strategy: return
+  _wrong_strategy_scope(strategy, context)


 def _require_distribution_strategy_scope_extended(extended):
@@ -143,8 +144,8 @@ def _require_distribution_strategy_scope_extended(extended):
   context = _get_per_thread_mode()
   if context.distribution_strategy.extended is extended: return
   # Report error.
-  distribution_strategy = extended._container_strategy()  # pylint: disable=protected-access
-  _wrong_distribution_strategy_scope(distribution_strategy, context)
+  strategy = extended._container_strategy()  # pylint: disable=protected-access
+  _wrong_strategy_scope(strategy, context)


 # ------------------------------------------------------------------------------
@@ -153,15 +154,18 @@ def _require_distribution_strategy_scope_extended(extended):


 class _CurrentDistributionContext(object):
-  """Context manager for setting the `DistributionStrategy` and var creator."""
+  """Context manager setting the current `tf.distribute.Strategy`.
+
+  Also: overrides the variable creator and optionally the current device.
+  """

   def __init__(self,
-               distribution_strategy,
+               strategy,
                var_creator_scope,
                var_scope=None,
                default_device=None):
     self._context = distribution_strategy_context._CrossReplicaThreadMode(  # pylint: disable=protected-access
-        distribution_strategy)
+        strategy)
     self._var_creator_scope = var_creator_scope
     self._var_scope = var_scope
     if default_device:
@@ -190,8 +194,8 @@ class _CurrentDistributionContext(object):
 class _SameScopeAgainContext(object):
   """Trivial context manager when you are already in `scope()`."""

-  def __init__(self, distribution_strategy):
-    self._distribution_strategy = distribution_strategy
+  def __init__(self, strategy):
+    self._distribution_strategy = strategy

   def __enter__(self):
     return self._distribution_strategy
@@ -201,6 +205,7 @@ class _SameScopeAgainContext(object):


 # TODO(yuefengz): add more replication modes.
+@tf_export("distribute.InputReplicationMode")
 class InputReplicationMode(enum.Enum):
   """Replication mode for input function."""

@@ -211,6 +216,7 @@ class InputReplicationMode(enum.Enum):
   PER_WORKER = "PER_WORKER"


+@tf_export("distribute.InputContext")
 class InputContext(object):
   """A class wrapping information needed by an input function.

@@ -278,6 +284,7 @@ class InputContext(object):
 # Base classes for all distribution strategies.


+@tf_export("distribute.Strategy")
 class DistributionStrategy(object):
   """A list of devices with a state & compute distribution policy.

@@ -301,14 +308,14 @@ class DistributionStrategy(object):

   @property
   def extended(self):
-    """`tf.contrib.distribute.DistributionStrategyExtended` with new methods."""
+    """`tf.distribute.StrategyExtended` with additional methods."""
     return self._extended

   def scope(self):
-    """Returns a context manager selecting this DistributionStrategy as current.
+    """Returns a context manager selecting this Strategy as current.

-    Inside a `with distribution_strategy.scope():` code block, this thread
-    will use a variable creator set by `distribution_strategy`, and will
+    Inside a `with strategy.scope():` code block, this thread
+    will use a variable creator set by `strategy`, and will
     enter its "cross-replica context".

     Returns:
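To make the `scope()` semantics concrete, a hedged sketch of the intended usage (`strategy` stands for any `tf.distribute.Strategy`; a concrete implementation is assumed to exist elsewhere):

```python
with strategy.scope():
  # Variable creation goes through the strategy's variable creator, and the
  # thread is now in this strategy's cross-replica context.
  v = tf.get_variable("v", shape=[], initializer=tf.zeros_initializer())

  # Cross-replica APIs such as strategy.extended.reduce_to() and
  # strategy.extended.update() may be used here.
```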
@@ -330,20 +337,20 @@ class DistributionStrategy(object):
   def distribute_dataset(self, dataset_fn):
     """Return a `dataset` split across all replicas. DEPRECATED.

-    DEPRECATED: Please use `make_dataset_iterator()` or
-    `make_input_fn_iterator()` instead.
+    DEPRECATED: Please use `make_dataset_iterator` or
+    `make_input_fn_iterator` instead.

-    Suitable for providing input to for `extended.call_for_each_replica()` by
+    Suitable for providing input to `extended.call_for_each_replica()` by
     creating an iterator:

     ```
     def dataset_fn():
       return tf.data.Dataset.from_tensors([[1.]]).repeat()

-    with distribution_strategy.scope():
-      distributed_dataset = distribution_strategy.distribute_dataset(dataset_fn)
+    with strategy.scope():
+      distributed_dataset = strategy.distribute_dataset(dataset_fn)
       iterator = distributed_dataset.make_initializable_iterator()
-      replica_results = distribution_strategy.extended.call_for_each_replica(
+      replica_results = strategy.extended.call_for_each_replica(
           replica_fn, args=(iterator.get_next(),))
     ```

@@ -374,8 +381,8 @@ class DistributionStrategy(object):
       replicas.

     Returns:
-      An `InputIterator` which returns inputs for each step of the computation.
-      User should call `initialize` on the returned iterator.
+      An `tf.distribute.InputIterator` which returns inputs for each step of the
+      computation. User should call `initialize` on the returned iterator.
     """
     return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access

@@ -384,26 +391,26 @@ class DistributionStrategy(object):
                              replication_mode=InputReplicationMode.PER_WORKER):
     """Returns an iterator split across replicas created from an input function.

-    The `input_fn` should take an `InputContext` object where information about
-    input sharding can be accessed:
+    The `input_fn` should take an `tf.distribute.InputContext` object where
+    information about input sharding can be accessed:

     ```
     def input_fn(input_context):
       d = tf.data.Dataset.from_tensors([[1.]]).repeat()
       return d.shard(input_context.num_input_pipelines,
                      input_context.input_pipeline_id)
-    with distribution_strategy.scope():
-      iterator = distribution_strategy.make_input_fn_iterator(
+    with strategy.scope():
+      iterator = strategy.make_input_fn_iterator(
           input_fn)
-      replica_results = distribution_strategy.call_for_each_replica(
+      replica_results = strategy.extended.call_for_each_replica(
           replica_fn, iterator.get_next())
     ```

     Args:
       input_fn: A function that returns a `tf.data.Dataset`. This function is
-        expected to take an `InputContext` object.
-      replication_mode: an enum value of `InputReplicationMode`. Only
-        `PER_WORKER` is supported currently.
+        expected to take an `tf.distribute.InputContext` object.
+      replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
+        Only `PER_WORKER` is supported currently.

     Returns:
       An iterator object that can be initialized and fetched next element.
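A hedged end-to-end sketch of an `input_fn` that uses the `tf.distribute.InputContext` it receives, combining the sharding shown above with `get_per_replica_batch_size` (the batch size and dataset contents are placeholders):

```python
def input_fn(input_context):
  global_batch_size = 64
  per_replica_batch = input_context.get_per_replica_batch_size(global_batch_size)
  d = tf.data.Dataset.from_tensors([[1.]]).repeat()
  # Each input pipeline reads only its own shard of the data.
  d = d.shard(input_context.num_input_pipelines,
              input_context.input_pipeline_id)
  return d.batch(per_replica_batch)

# iterator = strategy.make_input_fn_iterator(input_fn)
```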
@@ -544,7 +551,7 @@ class DistributionStrategy(object):

     Args:
       value: A value returned by `extended.call_for_each_replica()` or a
-        variable created in `scope()`.
+        variable created in `scope`.

     Returns:
       A list of values contained in `value`. If `value` represents a single
@@ -638,17 +645,19 @@ class DistributionStrategy(object):
     raise RuntimeError("Must only deepcopy DistributionStrategy.")


+@tf_export("distribute.StrategyExtended")
 class DistributionStrategyExtended(object):
   """Additional APIs for algorithms that need to be distribution-aware.

   The intent is that you can write an algorithm in a stylized way and
-  it will be usable with a variety of different `DistributionStrategy`
+  it will be usable with a variety of different
+  `tf.distribute.Strategy`
   implementations. Each descendant will implement a different strategy
   for distributing the algorithm across multiple devices/machines.
   Furthermore, these changes can be hidden inside the specific layers
   and other library classes that need special treatment to run in a
   distributed setting, so that most users' model definition code can
-  run unchanged. The `DistributionStrategy` API works the same way
+  run unchanged. The `tf.distribute.Strategy` API works the same way
   with eager and graph execution.

   First let's introduce a few high-level concepts:
@@ -696,42 +705,33 @@ class DistributionStrategyExtended(object):

   We have then a few approaches we want to support:

-  * Code written (as if) with no knowledge of class `DistributionStrategy`.
+  * Code written (as if) with no knowledge of class `tf.distribute.Strategy`.
     This code should work as before, even if some of the layers, etc.
     used by that code are written to be distribution-aware. This is done
-    by having a default `DistributionStrategy` that gives ordinary behavior,
+    by having a default `tf.distribute.Strategy` that gives ordinary behavior,
     and by default being in a single replica context.
   * Ordinary model code that you want to run using a specific
-    `DistributionStrategy`. This can be as simple as:
+    `tf.distribute.Strategy`. This can be as simple as:

     ```
-    with my_distribution.scope():
-      iterator = my_distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
-      replica_train_ops = my_distribution.extended.call_for_each_replica(
+    with my_strategy.scope():
+      iterator = my_strategy.make_dataset_iterator(dataset)
+      session.run(iterator.initialize())
+      replica_train_ops = my_strategy.extended.call_for_each_replica(
           replica_fn, args=(iterator.get_next(),))
-      train_op = tf.group(my_distribution.unwrap(replica_train_ops))
+      train_op = my_strategy.group(replica_train_ops)
     ```

     This takes an ordinary `dataset` and `replica_fn` and runs it
-    distributed using a particular `DistributionStrategy` in
-    `my_distribution`. Any variables created in `replica_fn` are created
-    using `my_distribution`'s policy, and library functions called by
+    distributed using a particular `tf.distribute.Strategy` in
+    `my_strategy`. Any variables created in `replica_fn` are created
+    using `my_strategy`'s policy, and library functions called by
     `replica_fn` can use the `get_replica_context()` API to get enhanced
     behavior in this case.

-    You can also create an initializable iterator instead of a one-shot
-    iterator. In that case, you will need to ensure that you initialize the
-    iterator before calling get_next.
-    ```
-    iterator = my_distribution.distribute_dataset(
-        dataset).make_initializable_iterator())
-    session.run(iterator.initializer)
-    ```
-
   * If you want to write a distributed algorithm, you may use any of
-    the `DistributionStrategy` APIs inside a
-    `with my_distribution.scope():` block of code.
+    the `tf.distribute.Strategy` APIs inside a
+    `with my_strategy.scope():` block of code.

   Lower-level concepts:

@@ -758,7 +758,7 @@ class DistributionStrategyExtended(object):
   * Replica context vs. Cross-replica context: _replica context_ is when we
     are in some function that is being called once for each replica.
     Otherwise we are in cross-replica context, which is useful for
-    calling `DistributionStrategy` methods which operate across the
+    calling `tf.distribute.Strategy` methods which operate across the
     replicas (like `reduce_to()`). By default you start in a replica context
     (the default "single replica context") and then some methods can
     switch you back and forth, as described below.
@@ -778,7 +778,7 @@ class DistributionStrategyExtended(object):
     pick a consistent set of devices to pass to both
     `colocate_vars_with()` and `update_non_slot()`.

-  When using a `DistributionStrategy`, we have a new type dimension
+  When using a `tf.distribute.Strategy`, we have a new type dimension
   called _locality_ that says what values are compatible with which
   APIs:

@@ -856,19 +856,20 @@ class DistributionStrategyExtended(object):
   input and destination.

   Layers should expect to be called in a replica context, and can use
-  the `get_replica_context()` function to get a `ReplicaContext` object. The
+  the `tf.distribute.get_replica_context` function to get a
+  `tf.distribute.ReplicaContext` object. The
   `ReplicaContext` object has a `merge_call()` method for entering
   cross-replica context where you can use `reduce_to()` (or
   `batch_reduce_to()`) and then optionally `update()` to update state.

-  You may use this API whether or not a `DistributionStrategy` is
+  You may use this API whether or not a `tf.distribute.Strategy` is
   being used, since there is a default implementation of
-  `ReplicaContext` and `DistributionStrategy`.
+  `ReplicaContext` and `tf.distribute.Strategy`.

-  NOTE for new `DistributionStrategy` implementations: Please put all logic
-  in a subclass of `DistributionStrategyExtended`. The only code needed for
-  the `DistributionStrategy` subclass is for instantiating your subclass of
-  `DistributionStrategyExtended` in the `__init__` method.
+  NOTE for new `tf.distribute.Strategy` implementations: Please put all logic
+  in a subclass of `tf.distribute.StrategyExtended`. The only code needed for
+  the `tf.distribute.Strategy` subclass is for instantiating your subclass of
+  `tf.distribute.StrategyExtended` in the `__init__` method.
   """

   def __init__(self, container_strategy):
@@ -908,7 +909,7 @@ class DistributionStrategyExtended(object):
       if kwargs.pop("partitioner", None) is not None:
         tf_logging.log_first_n(
             tf_logging.WARN, "Partitioned variables are disabled when using "
-            "current DistributionStrategy.", 1)
+            "current tf.distribute.Strategy.", 1)
       return getter(*args, **kwargs)

     return _CurrentDistributionContext(
@@ -932,7 +933,7 @@ class DistributionStrategyExtended(object):
     (read-only) value of any other variable.

     Args:
-      v: A variable allocated within the scope of this `DistributionStrategy`.
+      v: A variable allocated within the scope of this `tf.distribute.Strategy`.

     Returns:
       A tensor representing the value of `v`, aggregated across replicas if
@@ -953,9 +954,9 @@ class DistributionStrategyExtended(object):
     Example usage:

     ```
-    with distribution_strategy.scope():
+    with strategy.scope():
       var1 = tf.get_variable(...)
-      with distribution_strategy.colocate_vars_with(v1):
+      with strategy.extended.colocate_vars_with(v1):
         # var2 and var3 will be created on the same device(s) as var1
         var2 = tf.get_variable(...)
         var3 = tf.get_variable(...)
@@ -964,7 +965,7 @@ class DistributionStrategyExtended(object):
         # operates on v1 from var1, v2 from var2, and v3 from var3

       # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
-      distribution_strategy.update(v1, fn, args=(v2, v3))
+      strategy.extended.update(v1, fn, args=(v2, v3))
     ```

     Args:
@@ -990,7 +991,7 @@ class DistributionStrategyExtended(object):
       if not isinstance(result, dataset_ops.Dataset):
         raise ValueError(
             "dataset_fn() must return a tf.data.Dataset when using a "
-            "DistributionStrategy.")
+            "tf.distribute.Strategy.")
       return result

   # TODO(josh11b): `PerReplicaDataset` currently only implements a few methods of
@@ -1320,7 +1321,7 @@ class DistributionStrategyExtended(object):

     Args:
       var_list: The list of variables being optimized, needed with the
-        default `DistributionStrategy`.
+        default `tf.distribute.Strategy`.
     """
     raise NotImplementedError("must be implemented in descendants")

@@ -1374,13 +1375,18 @@ class DistributionStrategyExtended(object):
 # around their model creation and graph definition. There is no
 # anticipated need to define descendants of _CurrentDistributionContext.
 # It sets the current DistributionStrategy for purposes of
-# `get_distribution_strategy()` and `has_distribution_strategy()`
+# `get_strategy()` and `has_strategy()`
 # and switches the thread mode to a "cross-replica context".
+@tf_export("distribute.ReplicaContext")
 class ReplicaContext(object):
-  """DistributionStrategy API inside a `call_for_each_replica()` call."""
+  """`tf.distribute.Strategy` API when in a replica context.

-  def __init__(self, distribution_strategy, replica_id_in_sync_group):
-    self._distribution_strategy = distribution_strategy
+  To be used inside your replicated step function, such as in a
+  `tf.distribute.StrategyExtended.call_for_each_replica` call.
+  """
+
+  def __init__(self, strategy, replica_id_in_sync_group):
+    self._distribution_strategy = strategy
     self._thread_context = distribution_strategy_context._InReplicaThreadMode(  # pylint: disable=protected-access
         self)
     self._replica_id_in_sync_group = replica_id_in_sync_group
@@ -1396,23 +1402,24 @@ class ReplicaContext(object):

     This allows communication and coordination when there are multiple calls
     to a model function triggered by a call to
-    `distribution.call_for_each_replica(model_fn, ...)`.
+    `strategy.extended.call_for_each_replica(model_fn, ...)`.

-    See `MirroredDistribution.call_for_each_replica()` for an explanation.
+    See `tf.distribute.StrategyExtended.call_for_each_replica` for an
+    explanation.

-    Otherwise, this is equivalent to:
+    If not inside a distributed scope, this is equivalent to:

     ```
-    distribution = get_distribution_strategy()
-    with cross-replica-context(distribution):
-      return merge_fn(distribution, *args, **kwargs)
+    strategy = tf.distribute.get_strategy()
+    with cross-replica-context(strategy):
+      return merge_fn(strategy, *args, **kwargs)
     ```

     Args:
       merge_fn: function that joins arguments from threads that are given as
-        PerReplica. It accepts `DistributionStrategy` object as the first
-        argument.
-      args: List or tuple with positional per-thread arguments for `merge_fn`
+        PerReplica. It accepts `tf.distribute.Strategy` object as
+        the first argument.
+      args: List or tuple with positional per-thread arguments for `merge_fn`.
       kwargs: Dict with keyword per-thread arguments for `merge_fn`.

     Returns:
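A hedged sketch of the `merge_call()` pattern the docstring above describes: per-replica code hands values to a `merge_fn` that runs once in cross-replica context (`compute_loss` is a hypothetical per-replica function):

```python
def replica_fn(inputs):
  loss = compute_loss(inputs)  # hypothetical per-replica computation

  def merge_fn(strategy, per_replica_loss):
    # Runs in cross-replica context; `strategy` is the enclosing
    # tf.distribute.Strategy object.
    return per_replica_loss

  return tf.distribute.get_replica_context().merge_call(
      merge_fn, args=(loss,))
```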
@@ -1446,8 +1453,14 @@ class ReplicaContext(object):
     return self._replica_id_in_sync_group

   @property
+  @doc_controls.do_not_generate_docs  # DEPRECATED, use `strategy`
   def distribution_strategy(self):
-    """The current `DistributionStrategy` object."""
+    """DEPRECATED: use `self.strategy` instead."""
+    return self._distribution_strategy
+
+  @property
+  def strategy(self):
+    """The current `tf.distribute.Strategy` object."""
     return self._distribution_strategy

   @property
@@ -1470,7 +1483,7 @@ class ReplicaContext(object):


 class _DefaultDistributionStrategy(DistributionStrategy):
-  """Default `DistributionStrategy` if none is explicitly selected."""
+  """Default `tf.distribute.Strategy` if none is explicitly selected."""

   def __init__(self):
     super(_DefaultDistributionStrategy, self).__init__(
@@ -1483,7 +1496,7 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
   def _scope(self, strategy):
     """Context manager setting a variable creator and `self` as current."""
     if distribution_strategy_context.has_distribution_strategy():
-      raise RuntimeError("Must not nest DistributionStrategy scopes.")
+      raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")

     def creator(next_creator, *args, **kwargs):
       _require_distribution_strategy_scope_strategy(strategy)
@@ -1555,13 +1568,13 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):

   @property
   def worker_devices(self):
-    raise RuntimeError(
-        "worker_devices() method unsupported by _DefaultDistributionStrategy.")
+    raise RuntimeError("worker_devices() method unsupported by default "
+                       "tf.distribute.Strategy.")

   @property
   def parameter_devices(self):
-    raise RuntimeError("parameter_devices() method unsupported by "
-                       "_DefaultDistributionStrategy.")
+    raise RuntimeError("parameter_devices() method unsupported by default "
+                       "tf.distribute.Strategy.")

   def non_slot_devices(self, var_list):
     return min(var_list, key=lambda x: x.name)
@@ -1600,8 +1613,8 @@ _original_from_proto = resource_variable_ops._from_proto_fn
 def _from_proto_fn(v, import_scope=None):
   if distribution_strategy_context.has_distribution_strategy():
     raise NotImplementedError(
-        "Deserialization of variables is not yet supported when using"
-        "distributed strategies.")
+        "Deserialization of variables is not yet supported when using a "
+        "tf.distribute.Strategy.")
   else:
     return _original_from_proto(v, import_scope=import_scope)

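A hedged sketch of a per-replica function reading the `ReplicaContext` properties touched above (it would run under some `strategy.extended.call_for_each_replica(fn)`):

```python
def fn():
  ctx = tf.distribute.get_replica_context()
  # Preferred spelling after this change; ctx.distribution_strategy is kept
  # but deprecated and hidden from the generated docs.
  current_strategy = ctx.strategy
  return ctx.replica_id_in_sync_group
```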
@@ -20,6 +20,7 @@ from __future__ import print_function

 from tensorflow.python.framework import ops
 from tensorflow.python.util.lazy_loader import LazyLoader
+from tensorflow.python.util.tf_export import tf_export


 # There is a circular dependency between this and `distribute` module. So we
@@ -85,6 +86,7 @@ def _get_per_thread_mode():
 # Public API for accessing the current thread mode


+@tf_export("distribute.get_replica_context")
 def get_replica_context():
   """Returns the current `tf.distribute.ReplicaContext` or `None`.

@@ -95,7 +97,7 @@ def get_replica_context():
   1. starts in the default (single-replica) replica context (this function
      will return the default `ReplicaContext` object);
   2. switches to cross-replica context (in which case this will return
-     `None`) when entering a `with DistributionStrategy.scope():` block;
+     `None`) when entering a `with tf.distribute.Strategy.scope():` block;
   3. switches to a (non-default) replica context inside
      `extended.call_for_each_replica(fn, ...)`;
   4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
@@ -103,11 +105,11 @@ def get_replica_context():
      this function will return `None`).

   Note that you can also go directly from step 1 to 4 to switch to a
-  cross-replica context for the default `DistributionStrategy`. You may
+  cross-replica context for the default `tf.distribute.Strategy`. You may
   also switch from the cross-replica context of 4 to a replica context by
   calling `extended.call_for_each_replica()`, jumping back to step 3.

-  Most `DistributionStrategy` methods may only be executed in
+  Most `tf.distribute.Strategy` methods may only be executed in
   a cross-replica context, in a replica context you should use the
   `ReplicaContext` API instead.

@@ -124,7 +126,7 @@ def get_replica_context():


 def get_cross_replica_context():
-  """Returns the current DistributionStrategy if in a cross-replica context.
+  """Returns the current tf.distribute.Strategy if in a cross-replica context.

   DEPRECATED: Please use `in_cross_replica_context()` and
   `get_distribution_strategy()` instead.
@@ -133,22 +135,22 @@ def get_cross_replica_context():

   1. starts in the default (single-replica) replica context;
   2. switches to cross-replica context when entering a
-     `with DistributionStrategy.scope():` block;
+     `with tf.distribute.Strategy.scope():` block;
   3. switches to a (non-default) replica context inside
      `call_for_each_replica(fn, ...)`;
   4. if `fn` calls `get_replica_context()->merge_call(merge_fn, ...)`, then
      inside `merge_fn` you are back in the cross-replica context.

   Note that you can also go directly from step 1 to 4 to switch to a
-  cross-replica context for the default `DistributionStrategy`. You may
+  cross-replica context for the default `tf.distribute.Strategy`. You may
   also switch from the cross-replica context of 4 to a replica context by
   calling `call_for_each_replica()`, jumping back to step 3.

-  Most `DistributionStrategy` methods may only be executed in
+  Most `tf.distribute.Strategy` methods may only be executed in
   a cross-replica context.

   Returns:
-    Returns the current `DistributionStrategy` object in a cross-replica
+    Returns the current `tf.distribute.Strategy` object in a cross-replica
     context, or `None`.

   Exactly one of `get_replica_context()` and `get_cross_replica_context()`
@@ -157,6 +159,7 @@ def get_cross_replica_context():
   return _get_per_thread_mode().cross_replica_context


+@tf_export("distribute.in_cross_replica_context")
 def in_cross_replica_context():
   """Returns True if in a cross-replica context.

@@ -170,31 +173,33 @@ def in_cross_replica_context():
   return _get_per_thread_mode().cross_replica_context is not None


+@tf_export("distribute.get_strategy")
 def get_distribution_strategy():
-  """Returns the current `DistributionStrategy` object.
+  """Returns the current `tf.distribute.Strategy` object.

   Typically only used in a cross-replica context:

   ```
   if tf.distribute.in_cross_replica_context():
-    strategy = tf.distribute.get_distribution_strategy()
+    strategy = tf.distribute.get_strategy()
     ...
   ```

   Returns:
-    A `DistributionStrategy` object. Inside a
+    A `tf.distribute.Strategy` object. Inside a
     `with distribution_strategy.scope()` block, it returns
     `distribution_strategy`, otherwise it returns the default
-    (single-replica) `DistributionStrategy` object.
+    (single-replica) `tf.distribute.Strategy` object.
   """
   return _get_per_thread_mode().distribution_strategy


+@tf_export("distribute.has_strategy")
 def has_distribution_strategy():
-  """Return if there is a current non-default `DistributionStrategy`.
+  """Return if there is a current non-default `tf.distribute.Strategy`.

   Returns:
-    True if inside a `with distribution_strategy.scope():`.
+    True if inside a `with strategy.scope():`.
   """
   return get_distribution_strategy() is not _get_default_distribution_strategy()

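A hedged sketch of the newly exported context helpers in their default state, i.e. outside any `strategy.scope()` (values reflect the default single-replica strategy):

```python
import tensorflow as tf

print(tf.distribute.has_strategy())              # False: no explicit strategy
print(tf.distribute.in_cross_replica_context())  # False: default replica context

replica_ctx = tf.distribute.get_replica_context()
print(replica_ctx.num_replicas_in_sync)          # 1 for the default strategy

default_strategy = tf.distribute.get_strategy()  # the default Strategy object
```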
@ -0,0 +1,25 @@
|
|||||||
|
path: "tensorflow.distribute.InputContext"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "input_pipeline_id"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "num_input_pipelines"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "num_replicas_in_sync"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "get_per_replica_batch_size"
|
||||||
|
argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,8 @@
|
|||||||
|
path: "tensorflow.distribute.InputReplicationMode"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<enum \'InputReplicationMode\'>"
|
||||||
|
member {
|
||||||
|
name: "PER_WORKER"
|
||||||
|
mtype: "<enum \'InputReplicationMode\'>"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,12 @@
|
|||||||
|
path: "tensorflow.distribute.ReduceOp"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<enum \'ReduceOp\'>"
|
||||||
|
member {
|
||||||
|
name: "MEAN"
|
||||||
|
mtype: "<enum \'ReduceOp\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "SUM"
|
||||||
|
mtype: "<enum \'ReduceOp\'>"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,33 @@
|
|||||||
|
path: "tensorflow.distribute.ReplicaContext"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "devices"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "distribution_strategy"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "num_replicas_in_sync"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "replica_id_in_sync_group"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "strategy"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "merge_call"
|
||||||
|
argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,81 @@
|
|||||||
|
path: "tensorflow.distribute.StrategyExtended"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "experimental_between_graph"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "experimental_require_static_shapes"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "experimental_should_init"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "parameter_devices"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "should_checkpoint"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "should_save_summary"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "worker_devices"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "batch_reduce_to"
|
||||||
|
argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "broadcast_to"
|
||||||
|
argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "call_for_each_replica"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "colocate_vars_with"
|
||||||
|
argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "experimental_run_steps_on_iterator"
|
||||||
|
argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "non_slot_devices"
|
||||||
|
argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "read_var"
|
||||||
|
argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "reduce_to"
|
||||||
|
argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "update"
|
||||||
|
argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "update_non_slot"
|
||||||
|
argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "value_container"
|
||||||
|
argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,133 @@
path: "tensorflow.distribute.Strategy"
tf_class {
  is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "between_graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "extended"
    mtype: "<type \'property\'>"
  }
  member {
    name: "num_replicas_in_sync"
    mtype: "<type \'property\'>"
  }
  member {
    name: "parameter_devices"
    mtype: "<type \'property\'>"
  }
  member {
    name: "require_static_shapes"
    mtype: "<type \'property\'>"
  }
  member {
    name: "should_checkpoint"
    mtype: "<type \'property\'>"
  }
  member {
    name: "should_init"
    mtype: "<type \'property\'>"
  }
  member {
    name: "should_save_summary"
    mtype: "<type \'property\'>"
  }
  member {
    name: "worker_devices"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "batch_reduce"
    argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "broadcast"
    argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "call_for_each_replica"
    argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "colocate_vars_with"
    argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "configure"
    argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "distribute_dataset"
    argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "experimental_finalize"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "experimental_initialize"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "finalize"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "group"
    argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "initialize"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "make_dataset_iterator"
    argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "make_input_fn_iterator"
    argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
  }
  member_method {
    name: "non_slot_devices"
    argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "read_var"
    argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "reduce"
    argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "run_steps_on_dataset"
    argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
  }
  member_method {
    name: "scope"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "unwrap"
    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "update"
    argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "update_non_slot"
    argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "value_container"
    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
  }
}
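A brief sketch of the per-replica entry points on tf.distribute.Strategy listed above (scope, call_for_each_replica, unwrap). The wrapper function is hypothetical, and `strategy` again stands for a concrete implementation that lives outside this change (e.g. in tf.contrib.distribute).

    import tensorflow as tf

    def run_one_step(strategy, step_fn, inputs):
      # scope() makes this the current strategy, so anything step_fn creates is
      # distribution-aware; call_for_each_replica runs step_fn once per replica.
      with strategy.scope():
        per_replica = strategy.call_for_each_replica(step_fn, inputs)
      # unwrap() turns the per-replica container into a list of per-replica values.
      return strategy.unwrap(per_replica)
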
47 tensorflow/tools/api/golden/v1/tensorflow.distribute.pbtxt Normal file
@ -0,0 +1,47 @@
path: "tensorflow.distribute"
tf_module {
  member {
    name: "InputContext"
    mtype: "<type \'type\'>"
  }
  member {
    name: "InputReplicationMode"
    mtype: "<class \'enum.EnumMeta\'>"
  }
  member {
    name: "ReduceOp"
    mtype: "<class \'enum.EnumMeta\'>"
  }
  member {
    name: "ReplicaContext"
    mtype: "<type \'type\'>"
  }
  member {
    name: "Strategy"
    mtype: "<type \'type\'>"
  }
  member {
    name: "StrategyExtended"
    mtype: "<type \'type\'>"
  }
  member_method {
    name: "get_loss_reduction"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_replica_context"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_strategy"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "has_strategy"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "in_cross_replica_context"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
}
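The module-level functions in this golden file are the context queries of the new namespace. A hedged sketch of how they are typically consulted (the comments describe the intended behavior, not guarantees of this change):

    import tensorflow as tf

    strategy = tf.distribute.get_strategy()    # strategy of the enclosing scope, or a default one
    print(tf.distribute.has_strategy())        # whether a non-default strategy is in scope
    if tf.distribute.in_cross_replica_context():
      pass  # cross-replica APIs such as strategy.extended.reduce_to() apply here
    else:
      ctx = tf.distribute.get_replica_context()  # per-replica info, e.g. ctx.replica_id_in_sync_group
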
@ -324,6 +324,10 @@ tf_module {
     name: "debugging"
     mtype: "<type \'module\'>"
   }
+  member {
+    name: "distribute"
+    mtype: "<type \'module\'>"
+  }
   member {
     name: "distributions"
     mtype: "<type \'module\'>"
@ -0,0 +1,25 @@
path: "tensorflow.distribute.InputContext"
tf_class {
  is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "input_pipeline_id"
    mtype: "<type \'property\'>"
  }
  member {
    name: "num_input_pipelines"
    mtype: "<type \'property\'>"
  }
  member {
    name: "num_replicas_in_sync"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], "
  }
  member_method {
    name: "get_per_replica_batch_size"
    argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None"
  }
}
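InputContext is a plain value object handed to per-worker input functions; the numbers below are hypothetical and only illustrate the batch-size helper:

    import tensorflow as tf

    # Hypothetical setup: 2 input pipelines feeding 8 replicas in sync.
    ctx = tf.distribute.InputContext(
        num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=8)
    per_replica_bs = ctx.get_per_replica_batch_size(global_batch_size=64)  # 64 / 8 = 8
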
@ -0,0 +1,8 @@
path: "tensorflow.distribute.InputReplicationMode"
tf_class {
  is_instance: "<enum \'InputReplicationMode\'>"
  member {
    name: "PER_WORKER"
    mtype: "<enum \'InputReplicationMode\'>"
  }
}
@ -0,0 +1,12 @@
path: "tensorflow.distribute.ReduceOp"
tf_class {
  is_instance: "<enum \'ReduceOp\'>"
  member {
    name: "MEAN"
    mtype: "<enum \'ReduceOp\'>"
  }
  member {
    name: "SUM"
    mtype: "<enum \'ReduceOp\'>"
  }
}
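ReduceOp is an ordinary Python enum with the two members above; it names how per-replica values are combined when handed to the reduction APIs (reduce_to, batch_reduce_to). A trivial illustration:

    import tensorflow as tf

    mean_op = tf.distribute.ReduceOp.MEAN  # average the per-replica values
    sum_op = tf.distribute.ReduceOp.SUM    # add the per-replica values
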
@ -0,0 +1,33 @@
path: "tensorflow.distribute.ReplicaContext"
tf_class {
  is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "devices"
    mtype: "<type \'property\'>"
  }
  member {
    name: "distribution_strategy"
    mtype: "<type \'property\'>"
  }
  member {
    name: "num_replicas_in_sync"
    mtype: "<type \'property\'>"
  }
  member {
    name: "replica_id_in_sync_group"
    mtype: "<type \'property\'>"
  }
  member {
    name: "strategy"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "merge_call"
    argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
  }
}
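ReplicaContext is what code running inside a replica sees; merge_call is its hook back into cross-replica context. A hedged sketch, assuming the function below is invoked once per replica by a concrete strategy (not part of this change):

    import tensorflow as tf

    def replica_step(inputs):
      ctx = tf.distribute.get_replica_context()   # available inside a replica fn
      def merge_fn(strategy, value):
        # Runs once, in cross-replica context, with the merged per-replica values.
        return strategy.unwrap(value)
      return ctx.merge_call(merge_fn, args=(inputs,))
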
@ -0,0 +1,81 @@
path: "tensorflow.distribute.StrategyExtended"
tf_class {
  is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "experimental_between_graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "experimental_require_static_shapes"
    mtype: "<type \'property\'>"
  }
  member {
    name: "experimental_should_init"
    mtype: "<type \'property\'>"
  }
  member {
    name: "parameter_devices"
    mtype: "<type \'property\'>"
  }
  member {
    name: "should_checkpoint"
    mtype: "<type \'property\'>"
  }
  member {
    name: "should_save_summary"
    mtype: "<type \'property\'>"
  }
  member {
    name: "worker_devices"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "batch_reduce_to"
    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "broadcast_to"
    argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "call_for_each_replica"
    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
  }
  member_method {
    name: "colocate_vars_with"
    argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "experimental_run_steps_on_iterator"
    argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
  }
  member_method {
    name: "non_slot_devices"
    argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "read_var"
    argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "reduce_to"
    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "update"
    argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
  }
  member_method {
    name: "update_non_slot"
    argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
  }
  member_method {
    name: "value_container"
    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
  }
}
@ -0,0 +1,133 @@
path: "tensorflow.distribute.Strategy"
tf_class {
  is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
  is_instance: "<type \'object\'>"
  member {
    name: "between_graph"
    mtype: "<type \'property\'>"
  }
  member {
    name: "extended"
    mtype: "<type \'property\'>"
  }
  member {
    name: "num_replicas_in_sync"
    mtype: "<type \'property\'>"
  }
  member {
    name: "parameter_devices"
    mtype: "<type \'property\'>"
  }
  member {
    name: "require_static_shapes"
    mtype: "<type \'property\'>"
  }
  member {
    name: "should_checkpoint"
    mtype: "<type \'property\'>"
  }
  member {
    name: "should_init"
    mtype: "<type \'property\'>"
  }
  member {
    name: "should_save_summary"
    mtype: "<type \'property\'>"
  }
  member {
    name: "worker_devices"
    mtype: "<type \'property\'>"
  }
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "batch_reduce"
    argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "broadcast"
    argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "call_for_each_replica"
    argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "colocate_vars_with"
    argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "configure"
    argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
  }
  member_method {
    name: "distribute_dataset"
    argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "experimental_finalize"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "experimental_initialize"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "finalize"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "group"
    argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
  member_method {
    name: "initialize"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "make_dataset_iterator"
    argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "make_input_fn_iterator"
    argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
  }
  member_method {
    name: "non_slot_devices"
    argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "read_var"
    argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "reduce"
    argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "run_steps_on_dataset"
    argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
  }
  member_method {
    name: "scope"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "unwrap"
    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "update"
    argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "update_non_slot"
    argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
  }
  member_method {
    name: "value_container"
    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
  }
}
47 tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt Normal file
@ -0,0 +1,47 @@
path: "tensorflow.distribute"
tf_module {
  member {
    name: "InputContext"
    mtype: "<type \'type\'>"
  }
  member {
    name: "InputReplicationMode"
    mtype: "<class \'enum.EnumMeta\'>"
  }
  member {
    name: "ReduceOp"
    mtype: "<class \'enum.EnumMeta\'>"
  }
  member {
    name: "ReplicaContext"
    mtype: "<type \'type\'>"
  }
  member {
    name: "Strategy"
    mtype: "<type \'type\'>"
  }
  member {
    name: "StrategyExtended"
    mtype: "<type \'type\'>"
  }
  member_method {
    name: "get_loss_reduction"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_replica_context"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "get_strategy"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "has_strategy"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "in_cross_replica_context"
    argspec: "args=[], varargs=None, keywords=None, defaults=None"
  }
}
@ -180,6 +180,10 @@ tf_module {
     name: "debugging"
     mtype: "<type \'module\'>"
   }
+  member {
+    name: "distribute"
+    mtype: "<type \'module\'>"
+  }
   member {
     name: "double"
     mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"