From a4bcedd676ea035eb73a2d27437a542f9683d59d Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Mon, 19 Nov 2018 10:21:43 -0800
Subject: [PATCH] Create tf.distribute namespace with the base classes and
 standard APIs for distribution strategies. Strategy implementations will be
 in a future change. Also rename DistributionStrategy to
 tf.distribute.Strategy, and make other changes to match.

RELNOTES: Expose tf.distribute.Strategy as the new name for
tf.contrib.distribute.DistributionStrategy.
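
For illustration only, a minimal usage sketch of the renamed API surface (not
part of this change): it assumes a concrete implementation such as
tf.contrib.distribute.MirroredStrategy plus a `dataset` and `replica_fn`
defined elsewhere, mirroring the docstring example updated in distribute.py.

```
strategy = tf.contrib.distribute.MirroredStrategy()  # assumed implementation
with strategy.scope():
  # New accessors for the current strategy while inside scope().
  assert tf.distribute.has_strategy()
  assert tf.distribute.get_strategy() is strategy
  iterator = strategy.make_dataset_iterator(dataset)
  # In graph mode, run session.run(iterator.initialize()) before get_next().
  replica_results = strategy.extended.call_for_each_replica(
      replica_fn, args=(iterator.get_next(),))
  train_op = strategy.group(replica_results)
```

Only names exported by this change (tf.distribute.get_strategy, has_strategy,
Strategy, StrategyExtended) are exercised above; concrete strategy classes
still live in tf.contrib.distribute.
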
PiperOrigin-RevId: 222097121
---
 tensorflow/python/distribute/BUILD            |   5 +-
 tensorflow/python/distribute/reduce_util.py   |   3 +-
 .../tools/api/generator/api_init_files.bzl    |   1 +
 .../tools/api/generator/api_init_files_v1.bzl |   1 +
 .../python/tools/api/generator/doc_srcs.py    |   3 +-
 tensorflow/python/training/distribute.py      | 231 +++++++++---------
 .../training/distribution_strategy_context.py |  33 +--
 ...tensorflow.distribute.-input-context.pbtxt |  25 ++
 ...w.distribute.-input-replication-mode.pbtxt |   8 +
 .../v1/tensorflow.distribute.-reduce-op.pbtxt |  12 +
 ...nsorflow.distribute.-replica-context.pbtxt |  33 +++
 ...orflow.distribute.-strategy-extended.pbtxt |  81 ++++++
 .../v1/tensorflow.distribute.-strategy.pbtxt  | 133 ++++++++++
 .../api/golden/v1/tensorflow.distribute.pbtxt |  47 ++++
 .../tools/api/golden/v1/tensorflow.pbtxt      |   4 +
 ...tensorflow.distribute.-input-context.pbtxt |  25 ++
 ...w.distribute.-input-replication-mode.pbtxt |   8 +
 .../v2/tensorflow.distribute.-reduce-op.pbtxt |  12 +
 ...nsorflow.distribute.-replica-context.pbtxt |  33 +++
 ...orflow.distribute.-strategy-extended.pbtxt |  81 ++++++
 .../v2/tensorflow.distribute.-strategy.pbtxt  | 133 ++++++++++
 .../api/golden/v2/tensorflow.distribute.pbtxt |  47 ++++
 .../tools/api/golden/v2/tensorflow.pbtxt      |   4 +
 23 files changed, 837 insertions(+), 126 deletions(-)
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduce-op.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.distribute.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduce-op.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt
 create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt

diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 78ad8e7e158..04592ee6edf 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -216,7 +216,10 @@ py_library(
 py_library(
     name = "reduce_util",
     srcs = ["reduce_util.py"],
-    deps = [],
+    deps = [
+        "//tensorflow/python:util",
+        "//tensorflow/python:variable_scope",
+    ],
 )
 
 py_library(
diff --git a/tensorflow/python/distribute/reduce_util.py b/tensorflow/python/distribute/reduce_util.py
index 857c7993df0..2b2a4e9dba8 100644
--- a/tensorflow/python/distribute/reduce_util.py
+++ b/tensorflow/python/distribute/reduce_util.py
@@ -21,9 +21,10 @@ from __future__ import print_function
 import enum
 
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.util.tf_export import tf_export
 
 
-# TODO(priyag): Add this to tf.distribute namespace when it exists.
+@tf_export("distribute.ReduceOp")
 class ReduceOp(enum.Enum):
   """Indicates how a set of values should be reduced.
 
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index 2e3d1d93c4e..b41a1bc8f6f 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -9,6 +9,7 @@ TENSORFLOW_API_INIT_FILES = [
     "data/__init__.py",
     "data/experimental/__init__.py",
     "debugging/__init__.py",
+    "distribute/__init__.py",
     "dtypes/__init__.py",
     "errors/__init__.py",
     "experimental/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
index 89c817f6090..0fadec00ab0 100644
--- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
@@ -10,6 +10,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
     "data/__init__.py",
     "data/experimental/__init__.py",
     "debugging/__init__.py",
+    "distribute/__init__.py",
     "distributions/__init__.py",
     "dtypes/__init__.py",
     "errors/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/doc_srcs.py b/tensorflow/python/tools/api/generator/doc_srcs.py
index 0524b5e35cd..9e211d172ec 100644
--- a/tensorflow/python/tools/api/generator/doc_srcs.py
+++ b/tensorflow/python/tools/api/generator/doc_srcs.py
@@ -35,10 +35,11 @@ DocSource.__new__.__defaults__ = (None,) * len(DocSource._fields)
 
 _TENSORFLOW_DOC_SOURCES = {
     'app': DocSource(docstring_module_name='platform.app'),
+    'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
     'compat': DocSource(docstring_module_name='util.compat'),
+    'distribute': DocSource(docstring_module_name='training.distribute'),
     'distributions': DocSource(
         docstring_module_name='ops.distributions.distributions'),
-    'bitwise': DocSource(docstring_module_name='ops.bitwise_ops'),
     'errors': DocSource(docstring_module_name='framework.errors'),
     'gfile': DocSource(docstring_module_name='platform.gfile'),
     'graph_util': DocSource(docstring_module_name='framework.graph_util'),
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 7d68dd904f4..646c352a716 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Class DistributionStrategy, ReplicaContext, and supporting APIs."""
+"""Library for running a computation across multiple devices."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -38,19 +38,19 @@ from tensorflow.python.platform import tf_logging
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribution_strategy_context
 from tensorflow.python.util import nest
+from tensorflow.python.util.tf_export import tf_export
 from tensorflow.tools.docs import doc_controls
 
 
 # ------------------------------------------------------------------------------
-# Context tracking whether in a distribution.update() or .update_non_slot()
-# call.
+# Context tracking whether in a strategy.update() or .update_non_slot() call.
 
 
 _update_device = threading.local()
 
 
 def get_update_device():
-  """Get the current device if in a `DistributionStrategy.update()` call."""
+  """Get the current device if in a `tf.distribute.Strategy.update()` call."""
   try:
     return _update_device.current
   except AttributeError:
@@ -77,8 +77,9 @@ class UpdateContext(object):
 # Public utility functions.
 
 
+@tf_export("distribute.get_loss_reduction")
 def get_loss_reduction():
-  """Reduce op corresponding to the last loss reduction."""
+  """`tf.distribute.ReduceOp` corresponding to the last loss reduction."""
   loss_reduction = ops.get_default_graph()._last_loss_reduction  # pylint: disable=protected-access
   if loss_reduction == losses_impl.Reduction.SUM:
     return reduce_util.ReduceOp.SUM
@@ -95,25 +96,25 @@ def _require_cross_replica_context_extended(extended):
   cross_replica = context.cross_replica_context
   if cross_replica is not None and cross_replica.extended is extended:
     return
-  distribution_strategy = extended._container_strategy()  # pylint: disable=protected-access
+  strategy = extended._container_strategy()  # pylint: disable=protected-access
   # We have an error to report, figure out the right message.
-  if context.distribution_strategy is not distribution_strategy:
-    _wrong_distribution_strategy_scope(distribution_strategy, context)
+  if context.distribution_strategy is not strategy:
+    _wrong_strategy_scope(strategy, context)
   assert cross_replica is None
   raise RuntimeError("Method requires being in cross-replica context, use "
                      "get_replica_context().merge_call()")
 
 
-def _wrong_distribution_strategy_scope(distribution_strategy, context):
+def _wrong_strategy_scope(strategy, context):
   # Figure out the right error message.
   if not distribution_strategy_context.has_distribution_strategy():
     raise RuntimeError(
-        'Need to be inside "with distribution_strategy.scope()" for %s' %
-        (distribution_strategy,))
+        'Need to be inside "with strategy.scope()" for %s' %
+        (strategy,))
   else:
     raise RuntimeError(
-        "Mixing different DistributionStrategy objects: %s is not %s" %
-        (context.distribution_strategy, distribution_strategy))
+        "Mixing different tf.distribute.Strategy objects: %s is not %s" %
+        (context.distribution_strategy, strategy))
 
 
 def require_replica_context(replica_ctx):
@@ -124,18 +125,18 @@ def require_replica_context(replica_ctx):
   if context.replica_context is None:
     raise RuntimeError("Need to be inside `call_for_each_replica()`")
   if context.distribution_strategy is replica_ctx.distribution_strategy:
-    # Two different ReplicaContexts with the same DistributionStrategy.
+    # Two different ReplicaContexts with the same tf.distribute.Strategy.
     raise RuntimeError("Mismatching ReplicaContext.")
   raise RuntimeError(
-      "Mismatching DistributionStrategy objects: %s is not %s." %
+      "Mismatching tf.distribute.Strategy objects: %s is not %s." %
       (context.distribution_strategy, replica_ctx.distribution_strategy))
 
 
-def _require_distribution_strategy_scope_strategy(distribution_strategy):
-  """Verify in a `distribution_strategy.scope()` in this thread."""
+def _require_distribution_strategy_scope_strategy(strategy):
+  """Verify in a `strategy.scope()` in this thread."""
   context = _get_per_thread_mode()
-  if context.distribution_strategy is distribution_strategy: return
-  _wrong_distribution_strategy_scope(distribution_strategy, context)
+  if context.distribution_strategy is strategy: return
+  _wrong_strategy_scope(strategy, context)
 
 
 def _require_distribution_strategy_scope_extended(extended):
@@ -143,8 +144,8 @@ def _require_distribution_strategy_scope_extended(extended):
   context = _get_per_thread_mode()
   if context.distribution_strategy.extended is extended: return
   # Report error.
-  distribution_strategy = extended._container_strategy()  # pylint: disable=protected-access
-  _wrong_distribution_strategy_scope(distribution_strategy, context)
+  strategy = extended._container_strategy()  # pylint: disable=protected-access
+  _wrong_strategy_scope(strategy, context)
 
 
 # ------------------------------------------------------------------------------
@@ -153,15 +154,18 @@ def _require_distribution_strategy_scope_extended(extended):
 
 
 class _CurrentDistributionContext(object):
-  """Context manager for setting the `DistributionStrategy` and var creator."""
+  """Context manager setting the current `tf.distribute.Strategy`.
+
+  It also overrides the variable creator and, optionally, the current device.
+  """
 
   def __init__(self,
-               distribution_strategy,
+               strategy,
                var_creator_scope,
                var_scope=None,
                default_device=None):
     self._context = distribution_strategy_context._CrossReplicaThreadMode(  # pylint: disable=protected-access
-        distribution_strategy)
+        strategy)
     self._var_creator_scope = var_creator_scope
     self._var_scope = var_scope
     if default_device:
@@ -190,8 +194,8 @@ class _CurrentDistributionContext(object):
 class _SameScopeAgainContext(object):
   """Trivial context manager when you are already in `scope()`."""
 
-  def __init__(self, distribution_strategy):
-    self._distribution_strategy = distribution_strategy
+  def __init__(self, strategy):
+    self._distribution_strategy = strategy
 
   def __enter__(self):
     return self._distribution_strategy
@@ -201,6 +205,7 @@ class _SameScopeAgainContext(object):
 
 
 # TODO(yuefengz): add more replication modes.
+@tf_export("distribute.InputReplicationMode")
 class InputReplicationMode(enum.Enum):
   """Replication mode for input function."""
 
@@ -211,6 +216,7 @@ class InputReplicationMode(enum.Enum):
   PER_WORKER = "PER_WORKER"
 
 
+@tf_export("distribute.InputContext")
 class InputContext(object):
   """A class wrapping information needed by an input function.
 
@@ -278,6 +284,7 @@ class InputContext(object):
 # Base classes for all distribution strategies.
 
 
+@tf_export("distribute.Strategy")
 class DistributionStrategy(object):
   """A list of devices with a state & compute distribution policy.
 
@@ -301,14 +308,14 @@ class DistributionStrategy(object):
 
   @property
   def extended(self):
-    """`tf.contrib.distribute.DistributionStrategyExtended` with new methods."""
+    """`tf.distribute.StrategyExtended` with additional methods."""
     return self._extended
 
   def scope(self):
-    """Returns a context manager selecting this DistributionStrategy as current.
+    """Returns a context manager selecting this Strategy as current.
 
-    Inside a `with distribution_strategy.scope():` code block, this thread
-    will use a variable creator set by `distribution_strategy`, and will
+    Inside a `with strategy.scope():` code block, this thread
+    will use a variable creator set by `strategy`, and will
     enter its "cross-replica context".
 
     Returns:
@@ -330,20 +337,20 @@ class DistributionStrategy(object):
   def distribute_dataset(self, dataset_fn):
     """Return a `dataset` split across all replicas.  DEPRECATED.
 
-    DEPRECATED: Please use `make_dataset_iterator()` or
-    `make_input_fn_iterator()` instead.
+    DEPRECATED: Please use `make_dataset_iterator` or
+    `make_input_fn_iterator` instead.
 
-    Suitable for providing input to for `extended.call_for_each_replica()` by
+    Suitable for providing input to `extended.call_for_each_replica()` by
     creating an iterator:
 
     ```
     def dataset_fn():
       return tf.data.Dataset.from_tensors([[1.]]).repeat()
 
-    with distribution_strategy.scope():
-      distributed_dataset = distribution_strategy.distribute_dataset(dataset_fn)
+    with strategy.scope():
+      distributed_dataset = strategy.distribute_dataset(dataset_fn)
       iterator = distributed_dataset.make_initializable_iterator()
-      replica_results = distribution_strategy.extended.call_for_each_replica(
+      replica_results = strategy.extended.call_for_each_replica(
           replica_fn, args=(iterator.get_next(),))
     ```
 
@@ -374,8 +381,8 @@ class DistributionStrategy(object):
         replicas.
 
     Returns:
-      An `InputIterator` which returns inputs for each step of the computation.
-      User should call `initialize` on the returned iterator.
+      A `tf.distribute.InputIterator` which returns inputs for each step of the
+      computation. The user should call `initialize` on the returned iterator.
     """
     return self._extended._make_dataset_iterator(dataset)  # pylint: disable=protected-access
 
@@ -384,26 +391,26 @@ class DistributionStrategy(object):
                              replication_mode=InputReplicationMode.PER_WORKER):
     """Returns an iterator split across replicas created from an input function.
 
-    The `input_fn` should take an `InputContext` object where information about
-    input sharding can be accessed:
+    The `input_fn` should take a `tf.distribute.InputContext` object where
+    information about input sharding can be accessed:
 
     ```
     def input_fn(input_context):
       d = tf.data.Dataset.from_tensors([[1.]]).repeat()
       return d.shard(input_context.num_input_pipelines,
                      input_context.input_pipeline_id)
-    with distribution_strategy.scope():
-      iterator = distribution_strategy.make_input_fn_iterator(
+    with strategy.scope():
+      iterator = strategy.make_input_fn_iterator(
           input_fn)
-      replica_results = distribution_strategy.call_for_each_replica(
+      replica_results = strategy.extended.call_for_each_replica(
           replica_fn, iterator.get_next())
     ```
 
     Args:
       input_fn: A function that returns a `tf.data.Dataset`. This function is
-        expected to take an `InputContext` object.
-      replication_mode: an enum value of `InputReplicationMode`. Only
-        `PER_WORKER` is supported currently.
+        expected to take a `tf.distribute.InputContext` object.
+      replication_mode: an enum value of `tf.distribute.InputReplicationMode`.
+        Only `PER_WORKER` is supported currently.
 
     Returns:
       An iterator object that can be initialized and fetched next element.
@@ -544,7 +551,7 @@ class DistributionStrategy(object):
 
     Args:
       value: A value returned by `extended.call_for_each_replica()` or a
-        variable created in `scope()`.
+        variable created in `scope`.
 
     Returns:
       A list of values contained in `value`. If `value` represents a single
@@ -638,17 +645,19 @@ class DistributionStrategy(object):
     raise RuntimeError("Must only deepcopy DistributionStrategy.")
 
 
+@tf_export("distribute.StrategyExtended")
 class DistributionStrategyExtended(object):
   """Additional APIs for algorithms that need to be distribution-aware.
 
   The intent is that you can write an algorithm in a stylized way and
-  it will be usable with a variety of different `DistributionStrategy`
+  it will be usable with a variety of different `tf.distribute.Strategy`
   implementations. Each descendant will implement a different strategy
   for distributing the algorithm across multiple devices/machines.
   Furthermore, these changes can be hidden inside the specific layers
   and other library classes that need special treatment to run in a
   distributed setting, so that most users' model definition code can
-  run unchanged. The `DistributionStrategy` API works the same way
+  run unchanged. The `tf.distribute.Strategy` API works the same way
   with eager and graph execution.
 
   First let's introduce a few high-level concepts:
@@ -696,42 +705,33 @@ class DistributionStrategyExtended(object):
 
   We have then a few approaches we want to support:
 
-  * Code written (as if) with no knowledge of class `DistributionStrategy`.
+  * Code written (as if) with no knowledge of class `tf.distribute.Strategy`.
     This code should work as before, even if some of the layers, etc.
     used by that code are written to be distribution-aware. This is done
-    by having a default `DistributionStrategy` that gives ordinary behavior,
+    by having a default `tf.distribute.Strategy` that gives ordinary behavior,
     and by default being in a single replica context.
   * Ordinary model code that you want to run using a specific
-    `DistributionStrategy`. This can be as simple as:
+    `tf.distribute.Strategy`. This can be as simple as:
 
     ```
-    with my_distribution.scope():
-      iterator = my_distribution.distribute_dataset(
-          dataset).make_one_shot_iterator()
-      replica_train_ops = my_distribution.extended.call_for_each_replica(
+    with my_strategy.scope():
+      iterator = my_strategy.make_dataset_iterator(dataset)
+      session.run(iterator.initialize())
+      replica_train_ops = my_strategy.extended.call_for_each_replica(
           replica_fn, args=(iterator.get_next(),))
-      train_op = tf.group(my_distribution.unwrap(replica_train_ops))
+      train_op = my_strategy.group(replica_train_ops)
     ```
 
     This takes an ordinary `dataset` and `replica_fn` and runs it
-    distributed using a particular `DistributionStrategy` in
-    `my_distribution`. Any variables created in `replica_fn` are created
-    using `my_distribution`'s policy, and library functions called by
+    distributed using a particular `tf.distribute.Strategy` in
+    `my_strategy`. Any variables created in `replica_fn` are created
+    using `my_strategy`'s policy, and library functions called by
     `replica_fn` can use the `get_replica_context()` API to get enhanced
     behavior in this case.
 
-    You can also create an initializable iterator instead of a one-shot
-    iterator. In that case, you will need to ensure that you initialize the
-    iterator before calling get_next.
-    ```
-    iterator = my_distribution.distribute_dataset(
-        dataset).make_initializable_iterator())
-    session.run(iterator.initializer)
-    ```
-
   * If you want to write a distributed algorithm, you may use any of
-    the `DistributionStrategy` APIs inside a
-    `with my_distribution.scope():` block of code.
+    the `tf.distribute.Strategy` APIs inside a
+    `with my_strategy.scope():` block of code.
 
   Lower-level concepts:
 
@@ -758,7 +758,7 @@ class DistributionStrategyExtended(object):
   * Replica context vs. Cross-replica context: _replica context_ is when we
     are in some function that is being called once for each replica.
     Otherwise we are in cross-replica context, which is useful for
-    calling `DistributionStrategy` methods which operate across the
+    calling `tf.distribute.Strategy` methods which operate across the
     replicas (like `reduce_to()`). By default you start in a replica context
     (the default "single replica context") and then some methods can
     switch you back and forth, as described below.
@@ -778,7 +778,7 @@ class DistributionStrategyExtended(object):
     pick a consistent set of devices to pass to both
     `colocate_vars_with()` and `update_non_slot()`.
 
-  When using a `DistributionStrategy`, we have a new type dimension
+  When using a `tf.distribute.Strategy`, we have a new type dimension
   called _locality_ that says what values are compatible with which
   APIs:
 
@@ -856,19 +856,20 @@ class DistributionStrategyExtended(object):
   input and destination.
 
   Layers should expect to be called in a replica context, and can use
-  the `get_replica_context()` function to get a `ReplicaContext` object. The
+  the `tf.distribute.get_replica_context` function to get a
+  `tf.distribute.ReplicaContext` object. The
   `ReplicaContext` object has a `merge_call()` method for entering
   cross-replica context where you can use `reduce_to()` (or
   `batch_reduce_to()`) and then optionally `update()` to update state.
 
-  You may use this API whether or not a `DistributionStrategy` is
+  You may use this API whether or not a `tf.distribute.Strategy` is
   being used, since there is a default implementation of
-  `ReplicaContext` and `DistributionStrategy`.
+  `ReplicaContext` and `tf.distribute.Strategy`.
 
-  NOTE for new `DistributionStrategy` implementations: Please put all logic
-  in a subclass of `DistributionStrategyExtended`. The only code needed for
-  the `DistributionStrategy` subclass is for instantiating your subclass of
-  `DistributionStrategyExtended` in the `__init__` method.
+  NOTE for new `tf.distribute.Strategy` implementations: Please put all logic
+  in a subclass of `tf.distribute.StrategyExtended`. The only code needed for
+  the `tf.distribute.Strategy` subclass is for instantiating your subclass of
+  `tf.distribute.StrategyExtended` in the `__init__` method.
   """
 
   def __init__(self, container_strategy):
@@ -908,7 +909,7 @@ class DistributionStrategyExtended(object):
         if kwargs.pop("partitioner", None) is not None:
           tf_logging.log_first_n(
               tf_logging.WARN, "Partitioned variables are disabled when using "
-              "current DistributionStrategy.", 1)
+              "current tf.distribute.Strategy.", 1)
       return getter(*args, **kwargs)
 
     return _CurrentDistributionContext(
@@ -932,7 +933,7 @@ class DistributionStrategyExtended(object):
     (read-only) value of any other variable.
 
     Args:
-      v: A variable allocated within the scope of this `DistributionStrategy`.
+      v: A variable allocated within the scope of this `tf.distribute.Strategy`.
 
     Returns:
       A tensor representing the value of `v`, aggregated across replicas if
@@ -953,9 +954,9 @@ class DistributionStrategyExtended(object):
     Example usage:
 
     ```
-    with distribution_strategy.scope():
+    with strategy.scope():
       var1 = tf.get_variable(...)
-      with distribution_strategy.colocate_vars_with(v1):
+      with strategy.extended.colocate_vars_with(var1):
         # var2 and var3 will be created on the same device(s) as var1
         var2 = tf.get_variable(...)
         var3 = tf.get_variable(...)
@@ -964,7 +965,7 @@ class DistributionStrategyExtended(object):
         # operates on v1 from var1, v2 from var2, and v3 from var3
 
       # `fn` runs on every device `v1` is on, `v2` and `v3` will be there too.
-      distribution_strategy.update(v1, fn, args=(v2, v3))
+      strategy.extended.update(v1, fn, args=(v2, v3))
     ```
 
     Args:
@@ -990,7 +991,7 @@ class DistributionStrategyExtended(object):
     if not isinstance(result, dataset_ops.Dataset):
       raise ValueError(
           "dataset_fn() must return a tf.data.Dataset when using a "
-          "DistributionStrategy.")
+          "tf.distribute.Strategy.")
     return result
 
   # TODO(josh11b): `PerReplicaDataset` currently only implements a few methods of
@@ -1320,7 +1321,7 @@ class DistributionStrategyExtended(object):
 
     Args:
       var_list: The list of variables being optimized, needed with the
-        default `DistributionStrategy`.
+        default `tf.distribute.Strategy`.
     """
     raise NotImplementedError("must be implemented in descendants")
 
@@ -1374,13 +1375,18 @@ class DistributionStrategyExtended(object):
 #   around their model creation and graph definition. There is no
 #   anticipated need to define descendants of _CurrentDistributionContext.
 #   It sets the current DistributionStrategy for purposes of
-#   `get_distribution_strategy()` and `has_distribution_strategy()`
+#   `get_strategy()` and `has_strategy()`
 #   and switches the thread mode to a "cross-replica context".
+@tf_export("distribute.ReplicaContext")
 class ReplicaContext(object):
-  """DistributionStrategy API inside a `call_for_each_replica()` call."""
+  """`tf.distribute.Strategy` API when in a replica context.
 
-  def __init__(self, distribution_strategy, replica_id_in_sync_group):
-    self._distribution_strategy = distribution_strategy
+  To be used inside your replicated step function, such as in a
+  `tf.distribute.StrategyExtended.call_for_each_replica` call.
+  """
+
+  def __init__(self, strategy, replica_id_in_sync_group):
+    self._distribution_strategy = strategy
     self._thread_context = distribution_strategy_context._InReplicaThreadMode(  # pylint: disable=protected-access
         self)
     self._replica_id_in_sync_group = replica_id_in_sync_group
@@ -1396,23 +1402,24 @@ class ReplicaContext(object):
 
     This allows communication and coordination when there are multiple calls
     to a model function triggered by a call to
-    `distribution.call_for_each_replica(model_fn, ...)`.
+    `strategy.extended.call_for_each_replica(model_fn, ...)`.
 
-    See `MirroredDistribution.call_for_each_replica()` for an explanation.
+    See `tf.distribute.StrategyExtended.call_for_each_replica` for an
+    explanation.
 
-    Otherwise, this is equivalent to:
+    If not inside a distributed scope, this is equivalent to:
 
     ```
-    distribution = get_distribution_strategy()
-    with cross-replica-context(distribution):
-      return merge_fn(distribution, *args, **kwargs)
+    strategy = tf.distribute.get_strategy()
+    with cross-replica-context(strategy):
+      return merge_fn(strategy, *args, **kwargs)
     ```
 
     Args:
       merge_fn: function that joins arguments from threads that are given as
-        PerReplica. It accepts `DistributionStrategy` object as the first
-        argument.
-      args: List or tuple with positional per-thread arguments for `merge_fn`
+        PerReplica. It accepts a `tf.distribute.Strategy` object as
+        the first argument.
+      args: List or tuple with positional per-thread arguments for `merge_fn`.
       kwargs: Dict with keyword per-thread arguments for `merge_fn`.
 
     Returns:
@@ -1446,8 +1453,14 @@ class ReplicaContext(object):
     return self._replica_id_in_sync_group
 
   @property
+  @doc_controls.do_not_generate_docs  # DEPRECATED, use `strategy`
   def distribution_strategy(self):
-    """The current `DistributionStrategy` object."""
+    """DEPRECATED: use `self.stratgey` instead."""
+    return self._distribution_strategy
+
+  @property
+  def strategy(self):
+    """The current `tf.distribute.Strategy` object."""
     return self._distribution_strategy
 
   @property
@@ -1470,7 +1483,7 @@ class ReplicaContext(object):
 
 
 class _DefaultDistributionStrategy(DistributionStrategy):
-  """Default `DistributionStrategy` if none is explicitly selected."""
+  """Default `tf.distribute.Strategy` if none is explicitly selected."""
 
   def __init__(self):
     super(_DefaultDistributionStrategy, self).__init__(
@@ -1483,7 +1496,7 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
   def _scope(self, strategy):
     """Context manager setting a variable creator and `self` as current."""
     if distribution_strategy_context.has_distribution_strategy():
-      raise RuntimeError("Must not nest DistributionStrategy scopes.")
+      raise RuntimeError("Must not nest tf.distribute.Strategy scopes.")
 
     def creator(next_creator, *args, **kwargs):
       _require_distribution_strategy_scope_strategy(strategy)
@@ -1555,13 +1568,13 @@ class _DefaultDistributionExtended(DistributionStrategyExtended):
 
   @property
   def worker_devices(self):
-    raise RuntimeError(
-        "worker_devices() method unsupported by _DefaultDistributionStrategy.")
+    raise RuntimeError("worker_devices() method unsupported by default "
+                       "tf.distribute.Strategy.")
 
   @property
   def parameter_devices(self):
-    raise RuntimeError("parameter_devices() method unsupported by "
-                       "_DefaultDistributionStrategy.")
+    raise RuntimeError("parameter_devices() method unsupported by default "
+                       "tf.distribute.Strategy.")
 
   def non_slot_devices(self, var_list):
     return min(var_list, key=lambda x: x.name)
@@ -1600,8 +1613,8 @@ _original_from_proto = resource_variable_ops._from_proto_fn
 def _from_proto_fn(v, import_scope=None):
   if distribution_strategy_context.has_distribution_strategy():
     raise NotImplementedError(
-        "Deserialization of variables is not yet supported when using"
-        "distributed strategies.")
+        "Deserialization of variables is not yet supported when using a "
+        "tf.distribute.Strategy.")
   else:
     return _original_from_proto(v, import_scope=import_scope)
 
diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py
index c73d65d330e..0b3878de183 100644
--- a/tensorflow/python/training/distribution_strategy_context.py
+++ b/tensorflow/python/training/distribution_strategy_context.py
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 from tensorflow.python.framework import ops
 from tensorflow.python.util.lazy_loader import LazyLoader
+from tensorflow.python.util.tf_export import tf_export
 
 
 # There is a circular dependency between this and `distribute` module. So we
@@ -85,6 +86,7 @@ def _get_per_thread_mode():
 # Public API for accessing the current thread mode
 
 
+@tf_export("distribute.get_replica_context")
 def get_replica_context():
   """Returns the current `tf.distribute.ReplicaContext` or `None`.
 
@@ -95,7 +97,7 @@ def get_replica_context():
   1. starts in the default (single-replica) replica context (this function
      will return the default `ReplicaContext` object);
   2. switches to cross-replica context (in which case this will return
-     `None`) when entering a `with DistributionStrategy.scope():` block;
+     `None`) when entering a `with tf.distribute.Strategy.scope():` block;
   3. switches to a (non-default) replica context inside
      `extended.call_for_each_replica(fn, ...)`;
   4. if `fn` calls `get_replica_context().merge_call(merge_fn, ...)`, then
@@ -103,11 +105,11 @@ def get_replica_context():
      this function will return `None`).
 
   Note that you can also go directly from step 1 to 4 to switch to a
-  cross-replica context for the default `DistributionStrategy`. You may
+  cross-replica context for the default `tf.distribute.Strategy`. You may
   also switch from the cross-replica context of 4 to a replica context by
   calling `extended.call_for_each_replica()`, jumping back to step 3.
 
-  Most `DistributionStrategy` methods may only be executed in
+  Most `tf.distribute.Strategy` methods may only be executed in
   a cross-replica context, in a replica context you should use the
   `ReplicaContext` API instead.
 
@@ -124,7 +126,7 @@ def get_replica_context():
 
 
 def get_cross_replica_context():
-  """Returns the current DistributionStrategy if in a cross-replica context.
+  """Returns the current tf.distribute.Strategy if in a cross-replica context.
 
   DEPRECATED: Please use `in_cross_replica_context()` and
   `get_distribution_strategy()` instead.
@@ -133,22 +135,22 @@ def get_cross_replica_context():
 
   1. starts in the default (single-replica) replica context;
   2. switches to cross-replica context when entering a
-     `with DistributionStrategy.scope():` block;
+     `with tf.distribute.Strategy.scope():` block;
   3. switches to a (non-default) replica context inside
      `call_for_each_replica(fn, ...)`;
   4. if `fn` calls `get_replica_context()->merge_call(merge_fn, ...)`, then
      inside `merge_fn` you are back in the cross-replica context.
 
   Note that you can also go directly from step 1 to 4 to switch to a
-  cross-replica context for the default `DistributionStrategy`. You may
+  cross-replica context for the default `tf.distribute.Strategy`. You may
   also switch from the cross-replica context of 4 to a replica context by
   calling `call_for_each_replica()`, jumping back to step 3.
 
-  Most `DistributionStrategy` methods may only be executed in
+  Most `tf.distribute.Strategy` methods may only be executed in
   a cross-replica context.
 
   Returns:
-    Returns the current `DistributionStrategy` object in a cross-replica
+    Returns the current `tf.distribute.Strategy` object in a cross-replica
     context, or `None`.
 
     Exactly one of `get_replica_context()` and `get_cross_replica_context()`
@@ -157,6 +159,7 @@ def get_cross_replica_context():
   return _get_per_thread_mode().cross_replica_context
 
 
+@tf_export("distribute.in_cross_replica_context")
 def in_cross_replica_context():
   """Returns True if in a cross-replica context.
 
@@ -170,31 +173,33 @@ def in_cross_replica_context():
   return _get_per_thread_mode().cross_replica_context is not None
 
 
+@tf_export("distribute.get_strategy")
 def get_distribution_strategy():
-  """Returns the current `DistributionStrategy` object.
+  """Returns the current `tf.distribute.Strategy` object.
 
   Typically only used in a cross-replica context:
 
   ```
   if tf.distribute.in_cross_replica_context():
-    strategy = tf.distribute.get_distribution_strategy()
+    strategy = tf.distribute.get_strategy()
     ...
   ```
 
   Returns:
-    A `DistributionStrategy` object. Inside a
+    A `tf.distribute.Strategy` object. Inside a
     `with distribution_strategy.scope()` block, it returns
     `distribution_strategy`, otherwise it returns the default
-    (single-replica) `DistributionStrategy` object.
+    (single-replica) `tf.distribute.Strategy` object.
   """
   return _get_per_thread_mode().distribution_strategy
 
 
+@tf_export("distribute.has_strategy")
 def has_distribution_strategy():
-  """Return if there is a current non-default `DistributionStrategy`.
+  """Return if there is a current non-default `tf.distribute.Strategy`.
 
   Returns:
-    True if inside a `with distribution_strategy.scope():`.
+    True if inside a `with strategy.scope():`.
   """
   return get_distribution_strategy() is not _get_default_distribution_strategy()
 
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt
new file mode 100644
index 00000000000..c39ac5a20de
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-context.pbtxt
@@ -0,0 +1,25 @@
+path: "tensorflow.distribute.InputContext"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "input_pipeline_id"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "num_input_pipelines"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "num_replicas_in_sync"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], "
+  }
+  member_method {
+    name: "get_per_replica_batch_size"
+    argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt
new file mode 100644
index 00000000000..6a7a3a97aa0
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-input-replication-mode.pbtxt
@@ -0,0 +1,8 @@
+path: "tensorflow.distribute.InputReplicationMode"
+tf_class {
+  is_instance: "<enum \'InputReplicationMode\'>"
+  member {
+    name: "PER_WORKER"
+    mtype: "<enum \'InputReplicationMode\'>"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduce-op.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduce-op.pbtxt
new file mode 100644
index 00000000000..4899f38cad2
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-reduce-op.pbtxt
@@ -0,0 +1,12 @@
+path: "tensorflow.distribute.ReduceOp"
+tf_class {
+  is_instance: "<enum \'ReduceOp\'>"
+  member {
+    name: "MEAN"
+    mtype: "<enum \'ReduceOp\'>"
+  }
+  member {
+    name: "SUM"
+    mtype: "<enum \'ReduceOp\'>"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
new file mode 100644
index 00000000000..3eda6c60366
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-replica-context.pbtxt
@@ -0,0 +1,33 @@
+path: "tensorflow.distribute.ReplicaContext"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "devices"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "distribution_strategy"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "num_replicas_in_sync"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "replica_id_in_sync_group"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "strategy"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "merge_call"
+    argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
new file mode 100644
index 00000000000..3b502b534be
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy-extended.pbtxt
@@ -0,0 +1,81 @@
+path: "tensorflow.distribute.StrategyExtended"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "experimental_between_graph"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "experimental_require_static_shapes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "experimental_should_init"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "parameter_devices"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_checkpoint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_save_summary"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "worker_devices"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "batch_reduce_to"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "broadcast_to"
+    argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call_for_each_replica"
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+  }
+  member_method {
+    name: "colocate_vars_with"
+    argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "experimental_run_steps_on_iterator"
+    argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+  }
+  member_method {
+    name: "non_slot_devices"
+    argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "read_var"
+    argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reduce_to"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "update"
+    argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "update_non_slot"
+    argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "value_container"
+    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt
new file mode 100644
index 00000000000..4fe035b474f
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.-strategy.pbtxt
@@ -0,0 +1,133 @@
+path: "tensorflow.distribute.Strategy"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "between_graph"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "extended"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "num_replicas_in_sync"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "parameter_devices"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "require_static_shapes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_checkpoint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_init"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_save_summary"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "worker_devices"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "batch_reduce"
+    argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "broadcast"
+    argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "call_for_each_replica"
+    argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "colocate_vars_with"
+    argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "configure"
+    argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "distribute_dataset"
+    argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "experimental_finalize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "experimental_initialize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "finalize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "group"
+    argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "initialize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "make_dataset_iterator"
+    argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "make_input_fn_iterator"
+    argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
+  }
+  member_method {
+    name: "non_slot_devices"
+    argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "read_var"
+    argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reduce"
+    argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "run_steps_on_dataset"
+    argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+  }
+  member_method {
+    name: "scope"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "unwrap"
+    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "update"
+    argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "update_non_slot"
+    argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "value_container"
+    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.distribute.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.distribute.pbtxt
new file mode 100644
index 00000000000..4d833b54ba0
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.distribute.pbtxt
@@ -0,0 +1,47 @@
+path: "tensorflow.distribute"
+tf_module {
+  member {
+    name: "InputContext"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "InputReplicationMode"
+    mtype: "<class \'enum.EnumMeta\'>"
+  }
+  member {
+    name: "ReduceOp"
+    mtype: "<class \'enum.EnumMeta\'>"
+  }
+  member {
+    name: "ReplicaContext"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "Strategy"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "StrategyExtended"
+    mtype: "<type \'type\'>"
+  }
+  member_method {
+    name: "get_loss_reduction"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_replica_context"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_strategy"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "has_strategy"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "in_cross_replica_context"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 1bb0152f19e..6a45bc7b7f4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -324,6 +324,10 @@ tf_module {
     name: "debugging"
     mtype: "<type \'module\'>"
   }
+  member {
+    name: "distribute"
+    mtype: "<type \'module\'>"
+  }
   member {
     name: "distributions"
     mtype: "<type \'module\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt
new file mode 100644
index 00000000000..c39ac5a20de
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-context.pbtxt
@@ -0,0 +1,25 @@
+path: "tensorflow.distribute.InputContext"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.distribute.InputContext\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "input_pipeline_id"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "num_input_pipelines"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "num_replicas_in_sync"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'num_input_pipelines\', \'input_pipeline_id\', \'num_replicas_in_sync\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'1\'], "
+  }
+  member_method {
+    name: "get_per_replica_batch_size"
+    argspec: "args=[\'self\', \'global_batch_size\'], varargs=None, keywords=None, defaults=None"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt
new file mode 100644
index 00000000000..6a7a3a97aa0
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-input-replication-mode.pbtxt
@@ -0,0 +1,8 @@
+path: "tensorflow.distribute.InputReplicationMode"
+tf_class {
+  is_instance: "<enum \'InputReplicationMode\'>"
+  member {
+    name: "PER_WORKER"
+    mtype: "<enum \'InputReplicationMode\'>"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduce-op.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduce-op.pbtxt
new file mode 100644
index 00000000000..4899f38cad2
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-reduce-op.pbtxt
@@ -0,0 +1,12 @@
+path: "tensorflow.distribute.ReduceOp"
+tf_class {
+  is_instance: "<enum \'ReduceOp\'>"
+  member {
+    name: "MEAN"
+    mtype: "<enum \'ReduceOp\'>"
+  }
+  member {
+    name: "SUM"
+    mtype: "<enum \'ReduceOp\'>"
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
new file mode 100644
index 00000000000..3eda6c60366
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-replica-context.pbtxt
@@ -0,0 +1,33 @@
+path: "tensorflow.distribute.ReplicaContext"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.distribute.ReplicaContext\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "devices"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "distribution_strategy"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "num_replicas_in_sync"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "replica_id_in_sync_group"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "strategy"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'strategy\', \'replica_id_in_sync_group\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "merge_call"
+    argspec: "args=[\'self\', \'merge_fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+  }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
new file mode 100644
index 00000000000..3b502b534be
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy-extended.pbtxt
@@ -0,0 +1,81 @@
+path: "tensorflow.distribute.StrategyExtended"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategyExtended\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "experimental_between_graph"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "experimental_require_static_shapes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "experimental_should_init"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "parameter_devices"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_checkpoint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_save_summary"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "worker_devices"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'container_strategy\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "batch_reduce_to"
+    argspec: "args=[\'self\', \'reduce_op\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "broadcast_to"
+    argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "call_for_each_replica"
+    argspec: "args=[\'self\', \'fn\', \'args\', \'kwargs\'], varargs=None, keywords=None, defaults=[\'()\', \'None\'], "
+  }
+  member_method {
+    name: "colocate_vars_with"
+    argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "experimental_run_steps_on_iterator"
+    argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+  }
+  member_method {
+    name: "non_slot_devices"
+    argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "read_var"
+    argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reduce_to"
+    argspec: "args=[\'self\', \'reduce_op\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "update"
+    argspec: "args=[\'self\', \'var\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "update_non_slot"
+    argspec: "args=[\'self\', \'colocate_with\', \'fn\', \'args\', \'kwargs\', \'group\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'True\'], "
+  }
+  member_method {
+    name: "value_container"
+    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
+  }
+}
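Note: StrategyExtended collects the lower-level cross-replica hooks meant for library code (optimizers, metrics) rather than end users. A hedged sketch of how reduce_to and update from the argspecs above might combine to apply an averaged gradient; the variable used as `destinations`, the 0.1 step size, and the helper name are illustrative assumptions, not part of this change:

    import tensorflow as tf

    def apply_mean_gradient(extended, grad, var):
      # `extended` is assumed to be the tf.distribute.StrategyExtended of
      # the active strategy; `grad` holds per-replica gradients for `var`.
      reduced = extended.reduce_to(
          tf.distribute.ReduceOp.MEAN, grad, destinations=var)

      def assign_sub(v, g):
        # Runs once on each device that holds a copy of `var`.
        return v.assign_sub(0.1 * g)

      return extended.update(var, assign_sub, args=(reduced,))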
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt
new file mode 100644
index 00000000000..4fe035b474f
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.-strategy.pbtxt
@@ -0,0 +1,133 @@
+path: "tensorflow.distribute.Strategy"
+tf_class {
+  is_instance: "<class \'tensorflow.python.training.distribute.DistributionStrategy\'>"
+  is_instance: "<type \'object\'>"
+  member {
+    name: "between_graph"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "extended"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "num_replicas_in_sync"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "parameter_devices"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "require_static_shapes"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_checkpoint"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_init"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "should_save_summary"
+    mtype: "<type \'property\'>"
+  }
+  member {
+    name: "worker_devices"
+    mtype: "<type \'property\'>"
+  }
+  member_method {
+    name: "__init__"
+    argspec: "args=[\'self\', \'extended\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "batch_reduce"
+    argspec: "args=[\'self\', \'aggregation\', \'value_destination_pairs\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "broadcast"
+    argspec: "args=[\'self\', \'tensor\', \'destinations\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "call_for_each_replica"
+    argspec: "args=[\'self\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "colocate_vars_with"
+    argspec: "args=[\'self\', \'colocate_with_variable\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "configure"
+    argspec: "args=[\'self\', \'session_config\', \'cluster_spec\', \'task_type\', \'task_id\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+  }
+  member_method {
+    name: "distribute_dataset"
+    argspec: "args=[\'self\', \'dataset_fn\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "experimental_finalize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "experimental_initialize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "finalize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "group"
+    argspec: "args=[\'self\', \'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
+  member_method {
+    name: "initialize"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "make_dataset_iterator"
+    argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "make_input_fn_iterator"
+    argspec: "args=[\'self\', \'input_fn\', \'replication_mode\'], varargs=None, keywords=None, defaults=[\'InputReplicationMode.PER_WORKER\'], "
+  }
+  member_method {
+    name: "non_slot_devices"
+    argspec: "args=[\'self\', \'var_list\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "read_var"
+    argspec: "args=[\'self\', \'v\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "reduce"
+    argspec: "args=[\'self\', \'aggregation\', \'value\', \'destinations\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "run_steps_on_dataset"
+    argspec: "args=[\'self\', \'fn\', \'iterator\', \'iterations\', \'initial_loop_values\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+  }
+  member_method {
+    name: "scope"
+    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "unwrap"
+    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "update"
+    argspec: "args=[\'self\', \'var\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "update_non_slot"
+    argspec: "args=[\'self\', \'colocate_with\', \'fn\'], varargs=args, keywords=kwargs, defaults=None"
+  }
+  member_method {
+    name: "value_container"
+    argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=None"
+  }
+}
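Note: tf.distribute.Strategy is the user-facing entry point: variables are created under scope(), per-replica work goes through call_for_each_replica, and per-replica results come back out via unwrap. A small sketch under the assumption that a concrete Strategy subclass is available (this change only adds the base class and renames):

    import tensorflow as tf

    def build_replicated_loss(strategy, inputs, model_fn):
      # `model_fn` is assumed to build variables and return a scalar loss
      # for a single replica.
      with strategy.scope():
        # Variables created inside scope() are placed/managed by the
        # strategy.
        per_replica_loss = strategy.call_for_each_replica(model_fn, inputs)
      # unwrap() converts the per-replica result into a list with one
      # tensor per replica, e.g. for logging or manual averaging.
      return strategy.unwrap(per_replica_loss)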
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt
new file mode 100644
index 00000000000..4d833b54ba0
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt
@@ -0,0 +1,47 @@
+path: "tensorflow.distribute"
+tf_module {
+  member {
+    name: "InputContext"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "InputReplicationMode"
+    mtype: "<class \'enum.EnumMeta\'>"
+  }
+  member {
+    name: "ReduceOp"
+    mtype: "<class \'enum.EnumMeta\'>"
+  }
+  member {
+    name: "ReplicaContext"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "Strategy"
+    mtype: "<type \'type\'>"
+  }
+  member {
+    name: "StrategyExtended"
+    mtype: "<type \'type\'>"
+  }
+  member_method {
+    name: "get_loss_reduction"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_replica_context"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "get_strategy"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "has_strategy"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+  member_method {
+    name: "in_cross_replica_context"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
+}
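Note: the module golden also exports the zero-argument context helpers. A short sketch of how has_strategy, in_cross_replica_context, get_strategy, and get_replica_context distinguish the three situations code can find itself in; the function name is illustrative:

    import tensorflow as tf

    def describe_context():
      # All four helpers below take no arguments, per the argspecs above.
      if not tf.distribute.has_strategy():
        return "no explicit tf.distribute.Strategy in scope"
      if tf.distribute.in_cross_replica_context():
        strategy = tf.distribute.get_strategy()
        return "cross-replica context, %d replicas" % (
            strategy.num_replicas_in_sync)
      ctx = tf.distribute.get_replica_context()
      return "replica %s of %d" % (ctx.replica_id_in_sync_group,
                                   ctx.num_replicas_in_sync)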
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index 2d2ab7d0d36..10aa8b781fa 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -180,6 +180,10 @@ tf_module {
     name: "debugging"
     mtype: "<type \'module\'>"
   }
+  member {
+    name: "distribute"
+    mtype: "<type \'module\'>"
+  }
   member {
     name: "double"
     mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"