From a4bcedd676ea035eb73a2d27437a542f9683d59d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Mon, 19 Nov 2018 10:21:43 -0800 Subject: [PATCH] Create tf.distribute namespace with the base classes and standard APIs for distribution strategies. Strategy implementations will be in a future change. Also rename DistributionStrategy to tf.distribute.Strategy, and other changes to match. RELNOTES: Expose tf.distribute.Strategy as the new name for tf.contrib.distribute.DistributionStrategy. PiperOrigin-RevId: 222097121 --- tensorflow/python/distribute/BUILD | 5 +- tensorflow/python/distribute/reduce_util.py | 3 +- .../tools/api/generator/api_init_files.bzl | 1 + .../tools/api/generator/api_init_files_v1.bzl | 1 + .../python/tools/api/generator/doc_srcs.py | 3 +- tensorflow/python/training/distribute.py | 231 +++++++++--------- .../training/distribution_strategy_context.py | 33 +-- ...tensorflow.distribute.-input-context.pbtxt | 25 ++ ...w.distribute.-input-replication-mode.pbtxt | 8 + .../v1/tensorflow.distribute.-reduce-op.pbtxt | 12 + ...nsorflow.distribute.-replica-context.pbtxt | 33 +++ ...orflow.distribute.-strategy-extended.pbtxt | 81 ++++++ .../v1/tensorflow.distribute.-strategy.pbtxt | 133 ++++++++++ .../api/golden/v1/tensorflow.distribute.pbtxt | 47 ++++ .../tools/api/golden/v1/tensorflow.pbtxt | 4 + ...tensorflow.distribute.-input-context.pbtxt | 25 ++ ...w.distribute.-input-replication-mode.pbtxt | 8 + .../v2/tensorflow.distribute.-reduce-op.pbtxt | 12 + ...nsorflow.distribute.-replica-context.pbtxt | 33 +++ ...orflow.distribute.-strategy-extended.pbtxt | 81 ++++++ .../v2/tensorflow.distribute.-strategy.pbtxt | 133 ++++++++++ .../api/golden/v2/tensorflow.distribute.pbtxt | 47 ++++ .../tools/api/golden/v2/tensorflow.pbtxt | 4 + 23 files changed, 837 insertions(+), 126 deletions(-) create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduce-op.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduce-op.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 78ad8e7e158..04592ee6edf 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -216,7 +216,10 @@ py_library( py_library( name = "reduce_util", srcs = ["reduce_util.py"], - deps = [], + deps = [ + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + ], ) py_library( diff --git a/tensorflow/python/distribute/reduce_util.py b/tensorflow/python/distribute/reduce_util.py index 857c7993df0..2b2a4e9dba8 100644 --- a/tensorflow/python/distribute/reduce_util.py +++ b/tensorflow/python/distribute/reduce_util.py @@ -21,9 +21,10 @@ from __future__ import print_function import enum from tensorflow.python.ops import variable_scope +from tensorflow.python.util.tf_export import tf_export -# TODO(priyag): Add this to tf.distribute namespace when it exists. +@tf_export("distribute.ReduceOp") class ReduceOp(enum.Enum): """Indicates how a set of values should be reduced. diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl index 2e3d1d93c4e..b41a1bc8f6f 100644 --- a/tensorflow/python/tools/api/generator/api_init_files.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files.bzl @@ -9,6 +9,7 @@ TENSORFLOW_API_INIT_FILES = [ "data/__init__.py", "data/experimental/__init__.py", "debugging/__init__.py", + "distribute/__init__.py", "dtypes/__init__.py", "errors/__init__.py", "experimental/__init__.py", diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl index 89c817f6090..0fadec00ab0 100644 --- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl @@ -10,6 +10,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [ "data/__init__.py", "data/experimental/__init__.py", "debugging/__init__.py", + "distribute/__init__.py", "distributions/__init__.py", "dtypes/__init__.py", "errors/__init__.py", diff --git a/tensorflow/python/tools/api/generator/doc_srcs.py b/tensorflow/python/tools/api/generator/doc_srcs.py index 0524b5e35cd..9e211d172ec 100644 --- a/tensorflow/python/tools/api/generator/doc_srcs.py +++ b/tensorflow/python/tools/api/generator/doc_srcs.py @@ -35,10 +35,11 @@ DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields) _TENSORFLOW_DOC_SOURCES = { 'app': DocSource(docstring_module_name='platform.app'), + 'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'), 'compat': DocSource(docstring_module_name='util.compat'), + 'distribute': DocSource(docstring_module_name='training.distribute'), 'distributions': DocSource( docstring_module_name='ops.distributions.distributions'), - 'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'), 'errors': DocSource(docstring_module_name='framework.errors'), 'gfile': DocSource(docstring_module_name='platform.gfile'), 'graph_util': DocSource(docstring_module_name='framework.graph_util'), diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index 7d68dd904f4..646c352a716 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Class DistributionStrategy, ReplicaContext, and supporting APIs.""" +"""Library for running a computation across multiple devices.""" from __future__ import absolute_import from __future__ import division @@ -38,19 +38,19 @@ from tensorflow.python.platform import tf_logging from tensorflow.python.training import device_util from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util import nest +from tensorflow.python.util.tf_export import tf_export from tensorflow.tools.docs import doc_controls # ------------------------------------------------------------------------------ -# Context tracking whether in a distribution.update() or .update_non_slot() -# call. +# Context tracking whether in a strategy.update() or .update_non_slot() call. _update_device = threading.local() def get_update_device(): - """Get the current device if in a `DistributionStrategy.update()` call.""" + """Get the current device if in a `tf.distribute.Strategy.update()` call.""" try: return _update_device.current except AttributeError: @@ -77,8 +77,9 @@ class UpdateContext(object): # Public utility functions. +@tf_export("distribute.get_loss_reduction") def get_loss_reduction(): - """Reduce op corresponding to the last loss reduction.""" + """`tf.distribute.ReduceOp` corresponding to the last loss reduction.""" loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access if loss_reduction == losses_impl.Reduction.SUM: return reduce_util.ReduceOp.SUM @@ -95,25 +96,25 @@ def _require_cross_replica_context_extended(extended): cross_replica = context.cross_replica_context if cross_replica is not None and cross_replica.extended is extended: return - distribution_strategy = extended._container_strategy() # pylint: disable=protected-access + strategy = extended._container_strategy() # pylint: disable=protected-access # We have an error to report, figure out the right message. - if context.distribution_strategy is not distribution_strategy: - _wrong_distribution_strategy_scope(distribution_strategy, context) + if context.distribution_strategy is not strategy: + _wrong_strategy_scope(strategy, context) assert cross_replica is None raise RuntimeError("Method requires being in cross-replica context, use " "get_replica_context().merge_call()") -def _wrong_distribution_strategy_scope(distribution_strategy, context): +def _wrong_strategy_scope(strategy, context): # Figure out the right error message. if not distribution_strategy_context.has_distribution_strategy(): raise RuntimeError( - 'Need to be inside "with distribution_strategy.scope()" for %s' % - (distribution_strategy,)) + 'Need to be inside "with strategy.scope()" for %s' % + (strategy,)) else: raise RuntimeError( - "Mixing different DistributionStrategy objects: %s is not %s" % - (context.distribution_strategy, distribution_strategy)) + "Mixing different tf.distribute.Strategy objects: %s is not %s" % + (context.distribution_strategy, strategy)) def require_replica_context(replica_ctx): @@ -124,18 +125,18 @@ def require_replica_context(replica_ctx): if context.replica_context is None: raise RuntimeError("Need to be inside `call_for_each_replica()`") if context.distribution_strategy is replica_ctx.distribution_strategy: - # Two different ReplicaContexts with the same DistributionStrategy. + # Two different ReplicaContexts with the same tf.distribute.Strategy. raise RuntimeError("Mismatching ReplicaContext.") raise RuntimeError( - "Mismatching DistributionStrategy objects: %s is not %s." % + "Mismatching tf.distribute.Strategy objects: %s is not %s." % (context.distribution_strategy, replica_ctx.distribution_strategy)) -def _require_distribution_strategy_scope_strategy(distribution_strategy): - """Verify in a `distribution_strategy.scope()` in this thread.""" +def _require_distribution_strategy_scope_strategy(strategy): + """Verify in a `strategy.scope()` in this thread.""" context = _get_per_thread_mode() - if context.distribution_strategy is distribution_strategy: return - _wrong_distribution_strategy_scope(distribution_strategy, context) + if context.distribution_strategy is strategy: return + _wrong_strategy_scope(strategy, context) def _require_distribution_strategy_scope_extended(extended): @@ -143,8 +144,8 @@ def _require_distribution_strategy_scope_extended(extended): context = _get_per_thread_mode() if context.distribution_strategy.extended is extended: return # Report error. - distribution_strategy = extended._container_strategy() # pylint: disable=protected-access - _wrong_distribution_strategy_scope(distribution_strategy, context) + strategy = extended._container_strategy() # pylint: disable=protected-access + _wrong_strategy_scope(strategy, context) # ------------------------------------------------------------------------------ @@ -153,15 +154,18 @@ def _require_distribution_strategy_scope_extended(extended): class _CurrentDistributionContext(object): - """Context manager for setting the `DistributionStrategy` and var creator.""" + """Context manager setting the current `tf.distribute.Strategy`. + + Also: overrides the variable creator and optionally the current device. + """ def __init__(self, - distribution_strategy, + strategy, var_creator_scope, var_scope=None, default_device=None): self._context = distribution_strategy_context._CrossReplicaThreadMode( # pylint: disable=protected-access - distribution_strategy) + strategy) self._var_creator_scope = var_creator_scope self._var_scope = var_scope if default_device: @@ -190,8 +194,8 @@ class _CurrentDistributionContext(object): class _SameScopeAgainContext(object): """Trivial context manager when you are already in `scope()`.""" - def __init__(self, distribution_strategy): - self._distribution_strategy = distribution_strategy + def __init__(self, strategy): + self._distribution_strategy = strategy def __enter__(self): return self._distribution_strategy @@ -201,6 +205,7 @@ class _SameScopeAgainContext(object): # TODO(yuefengz): add more replication modes. +@tf_export("distribute.InputReplicationMode") class InputReplicationMode(enum.Enum): """Replication mode for input function.""" @@ -211,6 +216,7 @@ class InputReplicationMode(enum.Enum): PER_WORKER = "PER_WORKER" +@tf_export("distribute.InputContext") class InputContext(object): """A class wrapping information needed by an input function. @@ -278,6 +284,7 @@ class InputContext(object): # Base classes for all distribution strategies. +@tf_export("distribute.Strategy") class DistributionStrategy(object): """A list of devices with a state & compute distribution policy. @@ -301,14 +308,14 @@ class DistributionStrategy(object): @property def extended(self): - """`tf.contrib.distribute.DistributionStrategyExtended` with new methods.""" + """`tf.distribute.StrategyExtended` with additional methods.""" return self._extended def scope(self): - """Returns a context manager selecting this DistributionStrategy as current. + """Returns a context manager selecting this Strategy as current. - Inside a `with distribution_strategy.scope():` code block, this thread - will use a variable creator set by `distribution_strategy`, and will + Inside a `with strategy.scope():` code block, this thread + will use a variable creator set by `strategy`, and will enter its "cross-replica context". Returns: @@ -330,20 +337,20 @@ class DistributionStrategy(object): def distribute_dataset(self, dataset_fn): """Return a `dataset` split across all replicas. DEPRECATED. - DEPRECATED: Please use `make_dataset_iterator()` or - `make_input_fn_iterator()` instead. + DEPRECATED: Please use `make_dataset_iterator` or + `make_input_fn_iterator` instead. - Suitable for providing input to for `extended.call_for_each_replica()` by + Suitable for providing input to `extended.call_for_each_replica()` by creating an iterator: ``` def dataset_fn(): return tf.data.Dataset.from_tensors([[1.]]).repeat() - with distribution_strategy.scope(): - distributed_dataset = distribution_strategy.distribute_dataset(dataset_fn) + with strategy.scope(): + distributed_dataset = strategy.distribute_dataset(dataset_fn) iterator = distributed_dataset.make_initializable_iterator() - replica_results = distribution_strategy.extended.call_for_each_replica( + replica_results = strategy.extended.call_for_each_replica( replica_fn, args=(iterator.get_next(),)) ``` @@ -374,8 +381,8 @@ class DistributionStrategy(object): replicas. Returns: - An `InputIterator` which returns inputs for each step of the computation. - User should call `initialize` on the returned iterator. + An `tf.distribute.InputIterator` which returns inputs for each step of the + computation. User should call `initialize` on the returned iterator. """ return self._extended._make_dataset_iterator(dataset) # pylint: disable=protected-access @@ -384,26 +391,26 @@ class DistributionStrategy(object): replication_mode=InputReplicationMode.PER_WORKER): """Returns an iterator split across replicas created from an input function. - The `input_fn` should take an `InputContext` object where information about - input sharding can be accessed: + The `input_fn` should take an `tf.distribute.InputContext` object where + information about input sharding can be accessed: ``` def input_fn(input_context): d = tf.data.Dataset.from_tensors([[1.]]).repeat() return d.shard(input_context.num_input_pipelines, input_context.input_pipeline_id) - with distribution_strategy.scope(): - iterator = distribution_strategy.make_input_fn_iterator( + with strategy.scope(): + iterator = strategy.make_input_fn_iterator( input_fn) - replica_results = distribution_strategy.call_for_each_replica( + replica_results = strategy.extended.call_for_each_replica( replica_fn, iterator.get_next()) ``` Args: input_fn: A function that returns a `tf.data.Dataset`. This function is - expected to take an `InputContext` object. - replication_mode: an enum value of `InputReplicationMode`. Only - `PER_WORKER` is supported currently. + expected to take an `tf.distribute.InputContext` object. + replication_mode: an enum value of `tf.distribute.InputReplicationMode`. + Only `PER_WORKER` is supported currently. Returns: An iterator object that can be initialized and fetched next element. @@ -544,7 +551,7 @@ class DistributionStrategy(object): Args: value: A value returned by `extended.call_for_each_replica()` or a - variable created in `scope()`. + variable created in `scope`. Returns: A list of values contained in `value`. If `value` represents a single @@ -638,17 +645,19 @@ class DistributionStrategy(object): raise RuntimeError("Must only deepcopy DistributionStrategy.") +@tf_export("distribute.StrategyExtended") class DistributionStrategyExtended(object): """Additional APIs for algorithms that need to be distribution-aware. The intent is that you can write an algorithm in a stylized way and - it will be usable with a variety of different `DistributionStrategy` + it will be usable with a variety of different + `tf.distribute.Strategy` implementations. Each descendant will implement a different strategy for distributing the algorithm across multiple devices/machines. Furthermore, these changes can be hidden inside the specific layers and other library classes that need special treatment to run in a distributed setting, so that most users' model definition code can - run unchanged. The `DistributionStrategy` API works the same way + run unchanged. The `tf.distribute.Strategy` API works the same way with eager and graph execution. First let's introduce a few high-level concepts: @@ -696,42 +705,33 @@ class DistributionStrategyExtended(object): We have then a few approaches we want to support: - * Code written (as if) with no knowledge of class `DistributionStrategy`. + * Code written (as if) with no knowledge of class `tf.distribute.Strategy`. This code should work as before, even if some of the layers, etc. used by that code are written to be distribution-aware. This is done - by having a default `DistributionStrategy` that gives ordinary behavior, + by having a default `tf.distribute.Strategy` that gives ordinary behavior, and by default being in a single replica context. * Ordinary model code that you want to run using a specific - `DistributionStrategy`. This can be as simple as: + `tf.distribute.Strategy`. This can be as simple as: ``` - with my_distribution.scope(): - iterator = my_distribution.distribute_dataset( - dataset).make_one_shot_iterator() - replica_train_ops = my_distribution.extended.call_for_each_replica( + with my_strategy.scope(): + iterator = my_strategy.make_dataset_iterator(dataset) + session.run(iterator.initialize()) + replica_train_ops = my_strategy.extended.call_for_each_replica( replica_fn, args=(iterator.get_next(),)) - train_op = tf.group(my_distribution.unwrap(replica_train_ops)) + train_op = my_strategy.group(replica_train_ops) ``` This takes an ordinary `dataset` and `replica_fn` and runs it - distributed using a particular `DistributionStrategy` in - `my_distribution`. Any variables created in `replica_fn` are created - using `my_distribution`'s policy, and library functions called by + distributed using a particular `tf.distribute.Strategy` in + `my_strategy`. Any variables created in `replica_fn` are created + using `my_strategy`'s policy, and library functions called by `replica_fn` can use the `get_replica_context()` API to get enhanced behavior in this case. - You can also create an initializable iterator instead of a one-shot - iterator. In that case, you will need to ensure that you initialize the - iterator before calling get_next. - ``` - iterator = my_distribution.distribute_dataset( - dataset).make_initializable_iterator()) - session.run(iterator.initializer) - ``` - * If you want to write a distributed algorithm, you may use any of - the `DistributionStrategy` APIs inside a - `with my_distribution.scope():` block of code. + the `tf.distribute.Strategy` APIs inside a + `with my_strategy.scope():` block of code. Lower-level concepts: @@ -758,7 +758,7 @@ class DistributionStrategyExtended(object): * Replica context vs. Cross-replica context: _replica context_ is when we are in some function that is being called once for each replica. Otherwise we are in cross-replica context, which is useful for - calling `DistributionStrategy` methods which operate across the + calling `tf.distribute.Strategy` methods which operate across the replicas (like `reduce_to()`). By default you start in a replica context (the default "single replica context") and then some methods can switch you back and forth, as described below. @@ -778,7 +778,7 @@ class DistributionStrategyExtended(object): pick a consistent set of devices to pass to both `colocate_vars_with()` and `update_non_slot()`. - When using a `DistributionStrategy`, we have a new type dimension + When using a `tf.distribute.Strategy`, we have a new type dimension called _locality_ that says what values are compatible with which APIs: @@ -856,19 +856,20 @@ class DistributionStrategyExtended(object): input and destination. Layers should expect to be called in a replica context, and can use - the `get_replica_context()` function to get a `ReplicaContext` object. The + the `tf.distribute.get_replica_context` function to get a + `tf.distribute.ReplicaContext` object. The `ReplicaContext` object has a `merge_call()` method for entering cross-replica context where you can use `reduce_to()` (or `batch_reduce_to()`) and then optionally `update()` to update state. - You may use this API whether or not a `DistributionStrategy` is + You may use this API whether or not a `tf.distribute.Strategy` is being used, since there is a default implementation of - `ReplicaContext` and `DistributionStrategy`. + `ReplicaContext` and `tf.distribute.Strategy`. - NOTE for new `DistributionStrategy` implementations: Please put all logic - in a subclass of `DistributionStrategyExtended`. The only code needed for - the `DistributionStrategy` subclass is for instantiating your subclass of - `DistributionStrategyExtended` in the `__init__` method. + NOTE for new `tf.distribute.Strategy` implementations: Please put all logic + in a subclass of `tf.distribute.StrategyExtended`. The only code needed for + the `tf.distribute.Strategy` subclass is for instantiating your subclass of + `tf.distribute.StrategyExtended` in the `__init__` method. """ def __init__(self, container_strategy): @@ -908,7 +909,7 @@ class DistributionStrategyExtended(object): if kwargs.pop("partitioner", None) is not None: tf_logging.log_first_n( tf_logging.WARN, "Partitioned variables are disabled when using " - "current DistributionStrategy.", 1) + "current tf.distribute.Strategy.", 1) return getter(*args, **kwargs) return _CurrentDistributionContext( @@ -932,7 +933,7 @@ class DistributionStrategyExtended(object): (read-only) value of any other variable. Args: - v: A variable allocated within the scope of this `DistributionStrategy`. + v: A variable allocated within the scope of this `tf.distribute.Strategy`. Returns: A tensor representing the value of `v`, aggregated across replicas if @@ -953,9 +954,9 @@ class DistributionStrategyExtended(object): Example usage: ``` - with distribution_strategy.scope(): + with strategy.scope(): var1 = tf.get_variable(...) - with distribution_strategy.colocate_vars_with(v1): + with strategy.extended.colocate_vars_with(v1): # var2 and var3 will be created on the same device(s) as var1 var2 = tf.get_variable(...) var3 = tf.get_variable(...) @@ -964,7 +965,7 @@ class DistributionStrategyExtended(object): # operates on v1 from var1, v2 from var2, and v3 from var3 # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too. - distribution_strategy.update(v1, fn, args=(v2, v3)) + strategy.extended.update(v1, fn, args=(v2, v3)) ``` Args: @@ -990,7 +991,7 @@ class DistributionStrategyExtended(object): if not isinstance(result, dataset_ops.Dataset): raise ValueError( "dataset_fn() must return a tf.data.Dataset when using a " - "DistributionStrategy.") + "tf.distribute.Strategy.") return result # TODO(josh11b): `PerReplicaDataset` currently only implements a few methods of @@ -1320,7 +1321,7 @@ class DistributionStrategyExtended(object): Args: var_list: The list of variables being optimized, needed with the - default `DistributionStrategy`. + default `tf.distribute.Strategy`. """ raise NotImplementedError("must be implemented in descendants") @@ -1374,13 +1375,18 @@ class DistributionStrategyExtended(object): # around their model creation and graph definition. There is no # anticipated need to define descendants of _CurrentDistributionContext. # It sets the current DistributionStrategy for purposes of -# `get_distribution_strategy()` and `has_distribution_strategy()` +# `get_strategy()` and `has_strategy()` # and switches the thread mode to a "cross-replica context". +@tf_export("distribute.ReplicaContext") class ReplicaContext(object): - """DistributionStrategy API inside a `call_for_each_replica()` call.""" + """`tf.distribute.Strategy` API when in a replica context. - def __init__(self, distribution_strategy, replica_id_in_sync_group): - self._distribution_strategy = distribution_strategy + To be used inside your replicated step function, such as in a + `tf.distribute.StrategyExtended.call_for_each_replica` call. + """ + + def __init__(self, strategy, replica_id_in_sync_group): + self._distribution_strategy = strategy self._thread_context = distribution_strategy_context._InReplicaThreadMode( # pylint: disable=protected-access self) self._replica_id_in_sync_group = replica_id_in_sync_group @@ -1396,23 +1402,24 @@ class ReplicaContext(object): This allows communication and coordination when there are multiple calls to a model function triggered by a call to - `distribution.call_for_each_replica(model_fn, ...)`. + `strategy.extended.call_for_each_replica(model_fn, ...)`. - See `MirroredDistribution.call_for_each_replica()` for an explanation. + See `tf.distribute.StrategyExtended.call_for_each_replica` for an + explanation. - Otherwise, this is equivalent to: + If not inside a distributed scope, this is equivalent to: ``` - distribution = get_distribution_strategy() - with cross-replica-context(distribution): - return merge_fn(distribution, *args, **kwargs) + strategy = tf.distribute.get_strategy() + with cross-replica-context(strategy): + return merge_fn(strategy, *args, **kwargs) ``` Args: merge_fn: function that joins arguments from threads that are given as - PerReplica. It accepts `DistributionStrategy` object as the first - argument. - args: List or tuple with positional per-thread arguments for `merge_fn` + PerReplica. It accepts `tf.distribute.Strategy` object as + the first argument. + args: List or tuple with positional per-thread arguments for `merge_fn`. kwargs: Dict with keyword per-thread arguments for `merge_fn`. Returns: @@ -1446,8 +1453,14 @@ class ReplicaContext(object): return self._replica_id_in_sync_group @property + @doc_controls.do_not_generate_docs # DEPRECATED, use `strategy` def distribution_strategy(self): - """The current `DistributionStrategy` object.""" + """DEPRECATED: use `self.stratgey` instead.""" + return self._distribution_strategy + + @property + def strategy(self): + """The current `tf.distribute.Strategy` object.""" return self._distribution_strategy @property @@ -1470,7 +1483,7 @@ class ReplicaContext(object): class _DefaultDistributionStrategy(DistributionStrategy): - """Default `DistributionStrategy` if none is explicitly selected.""" + """Default `tf.distribute.Strategy` if none is explicitly selected.""" def __init__(self): super(_DefaultDistributionStrategy, self).__init__( @@ -1483,7 +1496,7 @@ class _DefaultDistributionExtended(DistributionStrategyExtended): def _scope(self, strategy): """Context manager setting a variable creator and `self` as current.""" if distribution_strategy_context.has_distribution_strategy(): - raise RuntimeError("Must not nest DistributionStrategy scopes.") + raise RuntimeError("Must not nest tf.distribute.Strategy scopes.") def creator(next_creator, *args, **kwargs): _require_distribution_strategy_scope_strategy(strategy) @@ -1555,13 +1568,13 @@ class _DefaultDistributionExtended(DistributionStrategyExtended): @property def worker_devices(self): - raise RuntimeError( - "worker_devices() method unsupported by _DefaultDistributionStrategy.") + raise RuntimeError("worker_devices() method unsupported by default " + "tf.distribute.Strategy.") @property def parameter_devices(self): - raise RuntimeError("parameter_devices() method unsupported by " - "_DefaultDistributionStrategy.") + raise RuntimeError("parameter_devices() method unsupported by default " + "tf.distribute.Strategy.") def non_slot_devices(self, var_list): return min(var_list, key=lambda x: x.name) @@ -1600,8 +1613,8 @@ _original_from_proto = resource_variable_ops._from_proto_fn def _from_proto_fn(v, import_scope=None): if distribution_strategy_context.has_distribution_strategy(): raise NotImplementedError( - "Deserialization of variables is not yet supported when using" - "distributed strategies.") + "Deserialization of variables is not yet supported when using a " + "tf.distribute.Strategy.") else: return _original_from_proto(v, import_scope=import_scope) diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py index c73d65d330e..0b3878de183 100644 --- a/tensorflow/python/training/distribution_strategy_context.py +++ b/tensorflow/python/training/distribution_strategy_context.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.util.lazy_loader import LazyLoader +from tensorflow.python.util.tf_export import tf_export # There is a circular dependency between this and `distribute` module. So we @@ -85,6 +86,7 @@ def _get_per_thread_mode(): # Public API for accessing the current thread mode +@tf_export("distribute.get_replica_context") def get_replica_context(): """Returns the current `tf.distribute.ReplicaContext` or `None`. @@ -95,7 +97,7 @@ def get_replica_context(): 1. starts in the default (single-replica) replica context (this function will return the default `ReplicaContext` object); 2. switches to cross-replica context (in which case this will return - `None`) when entering a `with DistributionStrategy.scope():` block; + `None`) when entering a `with tf.distribute.Strategy.scope():` block; 3. switches to a (non-default) replica context inside `extended.call_for_each_replica(fn, ...)`; 4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then @@ -103,11 +105,11 @@ def get_replica_context(): this function will return `None`). Note that you can also go directly from step 1 to 4 to switch to a - cross-replica context for the default `DistributionStrategy`. You may + cross-replica context for the default `tf.distribute.Strategy`. You may also switch from the cross-replica context of 4 to a replica context by calling `extended.call_for_each_replica()`, jumping back to step 3. - Most `DistributionStrategy` methods may only be executed in + Most `tf.distribute.Strategy` methods may only be executed in a cross-replica context, in a replica context you should use the `ReplicaContext` API instead. @@ -124,7 +126,7 @@ def get_replica_context(): def get_cross_replica_context(): - """Returns the current DistributionStrategy if in a cross-replica context. + """Returns the current tf.distribute.Strategy if in a cross-replica context. DEPRECATED: Please use `in_cross_replica_context()` and `get_distribution_strategy()` instead. @@ -133,22 +135,22 @@ def get_cross_replica_context(): 1. starts in the default (single-replica) replica context; 2. switches to cross-replica context when entering a - `with DistributionStrategy.scope():` block; + `with tf.distribute.Strategy.scope():` block; 3. switches to a (non-default) replica context inside `call_for_each_replica(fn, ...)`; 4. if `fn` calls `get_replica_context()->merge_call(merge_fn, ...)`, then inside `merge_fn` you are back in the cross-replica context. Note that you can also go directly from step 1 to 4 to switch to a - cross-replica context for the default `DistributionStrategy`. You may + cross-replica context for the default `tf.distribute.Strategy`. You may also switch from the cross-replica context of 4 to a replica context by calling `call_for_each_replica()`, jumping back to step 3. - Most `DistributionStrategy` methods may only be executed in + Most `tf.distribute.Strategy` methods may only be executed in a cross-replica context. Returns: - Returns the current `DistributionStrategy` object in a cross-replica + Returns the current `tf.distribute.Strategy` object in a cross-replica context, or `None`. Exactly one of `get_replica_context()` and `get_cross_replica_context()` @@ -157,6 +159,7 @@ def get_cross_replica_context(): return _get_per_thread_mode().cross_replica_context +@tf_export("distribute.in_cross_replica_context") def in_cross_replica_context(): """Returns True if in a cross-replica context. @@ -170,31 +173,33 @@ def in_cross_replica_context(): return _get_per_thread_mode().cross_replica_context is not None +@tf_export("distribute.get_strategy") def get_distribution_strategy(): - """Returns the current `DistributionStrategy` object. + """Returns the current `tf.distribute.Strategy` object. Typically only used in a cross-replica context: ``` if tf.distribute.in_cross_replica_context(): - strategy = tf.distribute.get_distribution_strategy() + strategy = tf.distribute.get_strategy() ... ``` Returns: - A `DistributionStrategy` object. Inside a + A `tf.distribute.Strategy` object. Inside a `with distribution_strategy.scope()` block, it returns `distribution_strategy`, otherwise it returns the default - (single-replica) `DistributionStrategy` object. + (single-replica) `tf.distribute.Strategy` object. """ return _get_per_thread_mode().distribution_strategy +@tf_export("distribute.has_strategy") def has_distribution_strategy(): - """Return if there is a current non-default `DistributionStrategy`. + """Return if there is a current non-default `tf.distribute.Strategy`. Returns: - True if inside a `with distribution_strategy.scope():`. + True if inside a `with strategy.scope():`. """ return get_distribution_strategy() is not _get_default_distribution_strategy() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt new file mode 100644 index 00000000000..c39ac5a20de --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt @@ -0,0 +1,25 @@ +path: "tensorflow.distribute.InputContext" +tf_class { + is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>" + is_instance: "<type \'object\'>" + member { + name: "input_pipeline_id" + mtype: "<type \'property\'>" + } + member { + name: "num_input_pipelines" + mtype: "<type \'property\'>" + } + member { + name: "num_replicas_in_sync" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], " + } + member_method { + name: "get_per_replica_batch_size" + argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt new file mode 100644 index 00000000000..6a7a3a97aa0 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt @@ -0,0 +1,8 @@ +path: "tensorflow.distribute.InputReplicationMode" +tf_class { + is_instance: "<enum \'InputReplicationMode\'>" + member { + name: "PER_WORKER" + mtype: "<enum \'InputReplicationMode\'>" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduce-op.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduce-op.pbtxt new file mode 100644 index 00000000000..4899f38cad2 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduce-op.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.distribute.ReduceOp" +tf_class { + is_instance: "<enum \'ReduceOp\'>" + member { + name: "MEAN" + mtype: "<enum \'ReduceOp\'>" + } + member { + name: "SUM" + mtype: "<enum \'ReduceOp\'>" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt new file mode 100644 index 00000000000..3eda6c60366 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt @@ -0,0 +1,33 @@ +path: "tensorflow.distribute.ReplicaContext" +tf_class { + is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>" + is_instance: "<type \'object\'>" + member { + name: "devices" + mtype: "<type \'property\'>" + } + member { + name: "distribution_strategy" + mtype: "<type \'property\'>" + } + member { + name: "num_replicas_in_sync" + mtype: "<type \'property\'>" + } + member { + name: "replica_id_in_sync_group" + mtype: "<type \'property\'>" + } + member { + name: "strategy" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "merge_call" + argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt new file mode 100644 index 00000000000..3b502b534be --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt @@ -0,0 +1,81 @@ +path: "tensorflow.distribute.StrategyExtended" +tf_class { + is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>" + is_instance: "<type \'object\'>" + member { + name: "experimental_between_graph" + mtype: "<type \'property\'>" + } + member { + name: "experimental_require_static_shapes" + mtype: "<type \'property\'>" + } + member { + name: "experimental_should_init" + mtype: "<type \'property\'>" + } + member { + name: "parameter_devices" + mtype: "<type \'property\'>" + } + member { + name: "should_checkpoint" + mtype: "<type \'property\'>" + } + member { + name: "should_save_summary" + mtype: "<type \'property\'>" + } + member { + name: "worker_devices" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "batch_reduce_to" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "broadcast_to" + argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call_for_each_replica" + argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], " + } + member_method { + name: "colocate_vars_with" + argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_run_steps_on_iterator" + argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " + } + member_method { + name: "non_slot_devices" + argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read_var" + argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reduce_to" + argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], " + } + member_method { + name: "update_non_slot" + argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], " + } + member_method { + name: "value_container" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt new file mode 100644 index 00000000000..4fe035b474f --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt @@ -0,0 +1,133 @@ +path: "tensorflow.distribute.Strategy" +tf_class { + is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>" + is_instance: "<type \'object\'>" + member { + name: "between_graph" + mtype: "<type \'property\'>" + } + member { + name: "extended" + mtype: "<type \'property\'>" + } + member { + name: "num_replicas_in_sync" + mtype: "<type \'property\'>" + } + member { + name: "parameter_devices" + mtype: "<type \'property\'>" + } + member { + name: "require_static_shapes" + mtype: "<type \'property\'>" + } + member { + name: "should_checkpoint" + mtype: "<type \'property\'>" + } + member { + name: "should_init" + mtype: "<type \'property\'>" + } + member { + name: "should_save_summary" + mtype: "<type \'property\'>" + } + member { + name: "worker_devices" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "batch_reduce" + argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "broadcast" + argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "call_for_each_replica" + argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "colocate_vars_with" + argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "configure" + argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "distribute_dataset" + argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_finalize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_initialize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "finalize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "group" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "initialize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_dataset_iterator" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_input_fn_iterator" + argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], " + } + member_method { + name: "non_slot_devices" + argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read_var" + argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reduce" + argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run_steps_on_dataset" + argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " + } + member_method { + name: "scope" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "unwrap" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "update_non_slot" + argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "value_container" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.pbtxt new file mode 100644 index 00000000000..4d833b54ba0 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.pbtxt @@ -0,0 +1,47 @@ +path: "tensorflow.distribute" +tf_module { + member { + name: "InputContext" + mtype: "<type \'type\'>" + } + member { + name: "InputReplicationMode" + mtype: "<class \'enum.EnumMeta\'>" + } + member { + name: "ReduceOp" + mtype: "<class \'enum.EnumMeta\'>" + } + member { + name: "ReplicaContext" + mtype: "<type \'type\'>" + } + member { + name: "Strategy" + mtype: "<type \'type\'>" + } + member { + name: "StrategyExtended" + mtype: "<type \'type\'>" + } + member_method { + name: "get_loss_reduction" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_replica_context" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_strategy" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "has_strategy" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "in_cross_replica_context" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index 1bb0152f19e..6a45bc7b7f4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -324,6 +324,10 @@ tf_module { name: "debugging" mtype: "<type \'module\'>" } + member { + name: "distribute" + mtype: "<type \'module\'>" + } member { name: "distributions" mtype: "<type \'module\'>" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt new file mode 100644 index 00000000000..c39ac5a20de --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt @@ -0,0 +1,25 @@ +path: "tensorflow.distribute.InputContext" +tf_class { + is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>" + is_instance: "<type \'object\'>" + member { + name: "input_pipeline_id" + mtype: "<type \'property\'>" + } + member { + name: "num_input_pipelines" + mtype: "<type \'property\'>" + } + member { + name: "num_replicas_in_sync" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], " + } + member_method { + name: "get_per_replica_batch_size" + argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt new file mode 100644 index 00000000000..6a7a3a97aa0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt @@ -0,0 +1,8 @@ +path: "tensorflow.distribute.InputReplicationMode" +tf_class { + is_instance: "<enum \'InputReplicationMode\'>" + member { + name: "PER_WORKER" + mtype: "<enum \'InputReplicationMode\'>" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduce-op.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduce-op.pbtxt new file mode 100644 index 00000000000..4899f38cad2 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduce-op.pbtxt @@ -0,0 +1,12 @@ +path: "tensorflow.distribute.ReduceOp" +tf_class { + is_instance: "<enum \'ReduceOp\'>" + member { + name: "MEAN" + mtype: "<enum \'ReduceOp\'>" + } + member { + name: "SUM" + mtype: "<enum \'ReduceOp\'>" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt new file mode 100644 index 00000000000..3eda6c60366 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt @@ -0,0 +1,33 @@ +path: "tensorflow.distribute.ReplicaContext" +tf_class { + is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>" + is_instance: "<type \'object\'>" + member { + name: "devices" + mtype: "<type \'property\'>" + } + member { + name: "distribution_strategy" + mtype: "<type \'property\'>" + } + member { + name: "num_replicas_in_sync" + mtype: "<type \'property\'>" + } + member { + name: "replica_id_in_sync_group" + mtype: "<type \'property\'>" + } + member { + name: "strategy" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "merge_call" + argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt new file mode 100644 index 00000000000..3b502b534be --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt @@ -0,0 +1,81 @@ +path: "tensorflow.distribute.StrategyExtended" +tf_class { + is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>" + is_instance: "<type \'object\'>" + member { + name: "experimental_between_graph" + mtype: "<type \'property\'>" + } + member { + name: "experimental_require_static_shapes" + mtype: "<type \'property\'>" + } + member { + name: "experimental_should_init" + mtype: "<type \'property\'>" + } + member { + name: "parameter_devices" + mtype: "<type \'property\'>" + } + member { + name: "should_checkpoint" + mtype: "<type \'property\'>" + } + member { + name: "should_save_summary" + mtype: "<type \'property\'>" + } + member { + name: "worker_devices" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "batch_reduce_to" + argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "broadcast_to" + argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "call_for_each_replica" + argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], " + } + member_method { + name: "colocate_vars_with" + argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_run_steps_on_iterator" + argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " + } + member_method { + name: "non_slot_devices" + argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read_var" + argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reduce_to" + argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], " + } + member_method { + name: "update_non_slot" + argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], " + } + member_method { + name: "value_container" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt new file mode 100644 index 00000000000..4fe035b474f --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt @@ -0,0 +1,133 @@ +path: "tensorflow.distribute.Strategy" +tf_class { + is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>" + is_instance: "<type \'object\'>" + member { + name: "between_graph" + mtype: "<type \'property\'>" + } + member { + name: "extended" + mtype: "<type \'property\'>" + } + member { + name: "num_replicas_in_sync" + mtype: "<type \'property\'>" + } + member { + name: "parameter_devices" + mtype: "<type \'property\'>" + } + member { + name: "require_static_shapes" + mtype: "<type \'property\'>" + } + member { + name: "should_checkpoint" + mtype: "<type \'property\'>" + } + member { + name: "should_init" + mtype: "<type \'property\'>" + } + member { + name: "should_save_summary" + mtype: "<type \'property\'>" + } + member { + name: "worker_devices" + mtype: "<type \'property\'>" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "batch_reduce" + argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "broadcast" + argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "call_for_each_replica" + argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "colocate_vars_with" + argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "configure" + argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], " + } + member_method { + name: "distribute_dataset" + argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_finalize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "experimental_initialize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "finalize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "group" + argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "initialize" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_dataset_iterator" + argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "make_input_fn_iterator" + argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], " + } + member_method { + name: "non_slot_devices" + argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "read_var" + argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "reduce" + argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "run_steps_on_dataset" + argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], " + } + member_method { + name: "scope" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "unwrap" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "update_non_slot" + argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None" + } + member_method { + name: "value_container" + argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt new file mode 100644 index 00000000000..4d833b54ba0 --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt @@ -0,0 +1,47 @@ +path: "tensorflow.distribute" +tf_module { + member { + name: "InputContext" + mtype: "<type \'type\'>" + } + member { + name: "InputReplicationMode" + mtype: "<class \'enum.EnumMeta\'>" + } + member { + name: "ReduceOp" + mtype: "<class \'enum.EnumMeta\'>" + } + member { + name: "ReplicaContext" + mtype: "<type \'type\'>" + } + member { + name: "Strategy" + mtype: "<type \'type\'>" + } + member { + name: "StrategyExtended" + mtype: "<type \'type\'>" + } + member_method { + name: "get_loss_reduction" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_replica_context" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_strategy" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "has_strategy" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "in_cross_replica_context" + argspec: "args=[], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 2d2ab7d0d36..10aa8b781fa 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -180,6 +180,10 @@ tf_module { name: "debugging" mtype: "<type \'module\'>" } + member { + name: "distribute" + mtype: "<type \'module\'>" + } member { name: "double" mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"