diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index b0d14e1bab3..e797d708032 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -299,6 +299,15 @@ py_library(
     ],
 )
 
+py_library(
+    name = "tpu_util",
+    srcs = ["tpu_util.py"],
+    deps = [
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python/tpu:tpu_py",
+    ],
+)
+
 py_library(
     name = "mirrored_strategy",
     srcs = ["mirrored_strategy.py"],
@@ -579,6 +588,7 @@ py_library(
         ":input_lib",
         ":numpy_dataset",
         ":reduce_util",
+        ":tpu_util",
         ":tpu_values",
         ":values",
         "//tensorflow/compiler/xla/experimental/xla_sharding",
@@ -809,6 +819,7 @@ py_library(
     srcs_version = "PY3",
     deps = [
         ":packed_distributed_variable",
+        ":tpu_util",
         ":values",
         ":values_util",
         "//tensorflow/python:framework_ops",
@@ -816,7 +827,6 @@ py_library(
         "//tensorflow/python:resource_variable_ops_gen",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:tape",
-        "//tensorflow/python/tpu:tpu_py",
     ],
 )
 
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index 4d6b0a204e2..b611b2a8cd8 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -38,6 +38,7 @@ from tensorflow.python.distribute import distribute_utils
 from tensorflow.python.distribute import input_lib
 from tensorflow.python.distribute import numpy_dataset
 from tensorflow.python.distribute import reduce_util
+from tensorflow.python.distribute import tpu_util
 from tensorflow.python.distribute import tpu_values
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
@@ -1099,7 +1100,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
 
     self._logical_device_stack.append(logical_device_id)
     try:
-      if tpu_values.enclosing_tpu_context() is None:
+      if tpu_util.enclosing_tpu_context() is None:
         yield
       else:
         with ops.device(tpu.core(logical_device_id)):
@@ -1213,7 +1214,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
   def _reduce_to(self, reduce_op, value, destinations, options):
     if (isinstance(value, values.DistributedValues) or
         tensor_util.is_tf_type(value)
-       ) and tpu_values.enclosing_tpu_context() is not None:
+       ) and tpu_util.enclosing_tpu_context() is not None:
       if reduce_op == reduce_util.ReduceOp.MEAN:
         # TODO(jhseu):  Revisit once we support model-parallelism.
         value *= (1. / self._num_replicas_in_sync)
@@ -1260,7 +1261,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
   def _update(self, var, fn, args, kwargs, group):
     assert isinstance(var, tpu_values.TPUVariableMixin) or isinstance(
         var, resource_variable_ops.BaseResourceVariable)
-    if tpu_values.enclosing_tpu_context() is not None:
+    if tpu_util.enclosing_tpu_context() is not None:
       if group:
         return fn(var, *args, **kwargs)
       else:
@@ -1317,7 +1318,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     # since the `1` gets broadcast as an int32 but global_step is int64.
     if isinstance(tensor, (float, int)):
       return tensor
-    if tpu_values.enclosing_tpu_context() is not None:
+    if tpu_util.enclosing_tpu_context() is not None:
       broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
       result = tpu_ops.all_to_all(
           broadcast_tensor,
diff --git a/tensorflow/python/distribute/tpu_util.py b/tensorflow/python/distribute/tpu_util.py
new file mode 100644
index 00000000000..4a8e2d35bab
--- /dev/null
+++ b/tensorflow/python/distribute/tpu_util.py
@@ -0,0 +1,35 @@
+# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for TPU."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.tpu import tpu
+
+
+def enclosing_tpu_context():
+  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
+  graph = ops.get_default_graph()
+  while graph is not None:
+    # pylint: disable=protected-access
+    context_ = graph._get_control_flow_context()
+    # pylint: enable=protected-access
+    while context_ is not None:
+      if isinstance(context_, tpu.TPUReplicateContext):
+        return context_
+      context_ = context_.outer_context
+    # This may be a FuncGraph due to defuns or v2 control flow. We need to
+    # find the original graph with the XLAControlFlowContext.
+    graph = getattr(graph, "outer_graph", None)
+  return None
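A minimal usage sketch of the relocated helper (illustrative only, not part of the change): the call sites in tpu_strategy.py and tpu_values.py branch on whether a TPUReplicateContext is active, roughly as below. The helper name _run_in_or_out_of_tpu_context is hypothetical.

    from tensorflow.python.distribute import tpu_util


    def _run_in_or_out_of_tpu_context(tpu_fn, fallback_fn):
      # Inside tpu.rewrite()/tpu.replicate(), enclosing_tpu_context() returns the
      # active TPUReplicateContext; outside of one it returns None, so callers
      # fall back to the ordinary (non-replicated) code path.
      if tpu_util.enclosing_tpu_context() is None:
        return fallback_fn()
      return tpu_fn()
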
diff --git a/tensorflow/python/distribute/tpu_values.py b/tensorflow/python/distribute/tpu_values.py
index dbe1e1ff363..34ed82251e9 100644
--- a/tensorflow/python/distribute/tpu_values.py
+++ b/tensorflow/python/distribute/tpu_values.py
@@ -25,6 +25,7 @@ from __future__ import print_function
 import contextlib
 
 from tensorflow.python.distribute import packed_distributed_variable as packed
+from tensorflow.python.distribute import tpu_util
 from tensorflow.python.distribute import values
 from tensorflow.python.distribute import values_util
 from tensorflow.python.eager import context
@@ -33,7 +34,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.tpu import tpu
 
 
 @contextlib.contextmanager
@@ -66,9 +66,7 @@ def _make_raw_assign_fn(raw_assign_fn):  # pylint: disable=missing-docstring
     handle = var.handle
     with _maybe_enter_graph(handle), _maybe_on_device(var):
       op = raw_assign_fn(
-          handle,
-          ops.convert_to_tensor(value, dtype=var.dtype),
-          name=name)
+          handle, ops.convert_to_tensor(value, dtype=var.dtype), name=name)
       with ops.control_dependencies([op]):
         return var._read_variable_op() if read_value else op  # pylint: disable=protected-access
 
@@ -89,14 +87,14 @@ class TPUVariableMixin(object):
       self._handle_id = self._common_name
 
   def __getattr__(self, name):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).__getattr__(name)
     else:
       raise AttributeError(
           "'{}' not accessible within a TPU context.".format(name))
 
   def get(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).get()
     else:
       raise NotImplementedError(
@@ -113,7 +111,7 @@ class TPUVariableMixin(object):
   def handle(self):
     """The handle by which this variable can be accessed."""
     # If we're in a tpu.rewrite(), return the replicated handle.
-    tpu_context = enclosing_tpu_context()
+    tpu_context = tpu_util.enclosing_tpu_context()
     if tpu_context is None or context.executing_eagerly():
       var = self._get_on_device_or_primary()
       if isinstance(var, packed.PackedVarAndDevice):
@@ -148,19 +146,19 @@ class TPUVariableMixin(object):
       return gen_resource_variable_ops.read_variable_op(handle, self.dtype)
 
   def read_value(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).read_value()
     else:
       return self._read_variable_op()
 
   def value(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self).value()
     else:
       return self._read_variable_op()
 
   def _as_graph_element(self):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self)._as_graph_element()  # pylint: disable=protected-access
     else:
       return None
@@ -177,7 +175,7 @@ class TPUVariableMixin(object):
   def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
     """Converts a variable to a tensor."""
     # pylint: disable=protected-access
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUVariableMixin, self)._dense_var_to_tensor(
           dtype=dtype, name=name, as_ref=as_ref)
     # pylint: enable=protected-access
@@ -187,23 +185,6 @@ class TPUVariableMixin(object):
       return self.handle if as_ref else self.read_value()
 
 
-def enclosing_tpu_context():
-  """Returns the TPUReplicateContext, which exists inside a tpu.rewrite()."""
-  graph = ops.get_default_graph()
-  while graph is not None:
-    # pylint: disable=protected-access
-    context_ = graph._get_control_flow_context()
-    # pylint: enable=protected-access
-    while context_ is not None:
-      if isinstance(context_, tpu.TPUReplicateContext):
-        return context_
-      context_ = context_.outer_context
-    # This may be a FuncGraph due to defuns or v2 control flow. We need to
-    # find the original graph with the XLAControlFlowContext.
-    graph = getattr(graph, "outer_graph", None)
-  return None
-
-
 class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
   """DistributedVariable subclass for TPUStrategy."""
 
@@ -274,9 +255,8 @@ class TPUDistributedVariable(TPUVariableMixin, values.DistributedVariable):
 class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
   """Holds a map from replica to TPU variables whose values are kept in sync."""
 
-  def assign_sub(self, value, use_locking=False, name=None,
-                 read_value=True):
-    if (enclosing_tpu_context() and
+  def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_sub_variable_op)(
@@ -285,12 +265,11 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_sub(self, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_sub(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
-  def assign_add(self, value, use_locking=False, name=None,
-                 read_value=True):
-    if (enclosing_tpu_context() and
+  def assign_add(self, value, use_locking=False, name=None, read_value=True):
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_add_variable_op)(
@@ -299,21 +278,20 @@ class TPUMirroredVariable(TPUVariableMixin, values.MirroredVariable):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_add(self, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_add(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def assign(self, value, use_locking=False, name=None, read_value=True):
-    if (enclosing_tpu_context() and
+    if (tpu_util.enclosing_tpu_context() and
         self.aggregation == variable_scope.VariableAggregation.NONE):
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_variable_op)(
-              self,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign(self, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+          self,
+          value=value,
+          use_locking=use_locking,
+          name=name,
+          read_value=read_value)
+    return assign(
+        self, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     if values_util.is_saving_non_distributed():
@@ -358,7 +336,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
   """Holds a map from replica to variables whose values are reduced on save."""
 
   def assign_sub(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign_sub(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -366,7 +344,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
                                                             **kwargs)
 
   def assign_add(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign_add(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -374,7 +352,7 @@ class TPUSyncOnReadVariable(TPUVariableMixin, values.SyncOnReadVariable):
                                                             **kwargs)
 
   def assign(self, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return values.SyncOnReadVariable.assign(self, *args, **kwargs)
     else:
       return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
@@ -408,8 +386,7 @@ def assign_add(var, value, use_locking=False, name=None, read_value=True):
 
 
 def assign(var, value, use_locking=False, name=None, read_value=True):
-  assign_fn = _make_raw_assign_fn(
-      gen_resource_variable_ops.assign_variable_op)
+  assign_fn = _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)
   return var._update(  # pylint: disable=protected-access
       update_fn=assign_fn,
       value=value,
@@ -427,9 +404,13 @@ class TPUAutoPolicy(values.AutoPolicy):
   scope.
   """
 
-  def assign_sub(self, var, value, use_locking=False, name=None,
+  def assign_sub(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
                  read_value=True):
-    if enclosing_tpu_context():
+    if tpu_util.enclosing_tpu_context():
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_sub_variable_op)(
               var,
@@ -437,12 +418,16 @@ class TPUAutoPolicy(values.AutoPolicy):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_sub(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_sub(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
-  def assign_add(self, var, value, use_locking=False, name=None,
+  def assign_add(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
                  read_value=True):
-    if enclosing_tpu_context():
+    if tpu_util.enclosing_tpu_context():
       return _make_raw_assign_fn(
           gen_resource_variable_ops.assign_add_variable_op)(
               var,
@@ -450,20 +435,19 @@ class TPUAutoPolicy(values.AutoPolicy):
               use_locking=use_locking,
               name=name,
               read_value=read_value)
-    return assign_add(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_add(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def assign(self, var, value, use_locking=False, name=None, read_value=True):
-    if enclosing_tpu_context():
-      return _make_raw_assign_fn(
-          gen_resource_variable_ops.assign_variable_op)(
-              var,
-              value=value,
-              use_locking=use_locking,
-              name=name,
-              read_value=read_value)
-    return assign(var, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+    if tpu_util.enclosing_tpu_context():
+      return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(
+          var,
+          value=value,
+          use_locking=use_locking,
+          name=name,
+          read_value=read_value)
+    return assign(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     raise NotImplementedError
@@ -504,19 +488,27 @@ class TPUOnWritePolicy(values.OnWritePolicy):
   values such as `NONE`, `SUM`, `MEAN` or `ONLY_FIRST_REPLICA`.
   """
 
-  def assign_sub(self, var, value, use_locking=False, name=None,
+  def assign_sub(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
                  read_value=True):
-    return assign_sub(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_sub(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
-  def assign_add(self, var, value, use_locking=False, name=None,
+  def assign_add(self,
+                 var,
+                 value,
+                 use_locking=False,
+                 name=None,
                  read_value=True):
-    return assign_add(var, value, use_locking=use_locking, name=name,
-                      read_value=read_value)
+    return assign_add(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def assign(self, var, value, use_locking=False, name=None, read_value=True):
-    return assign(var, value, use_locking=use_locking, name=name,
-                  read_value=read_value)
+    return assign(
+        var, value, use_locking=use_locking, name=name, read_value=read_value)
 
   def scatter_sub(self, *args, **kwargs):
     raise NotImplementedError
@@ -554,7 +546,7 @@ class TPUOnReadPolicy(values.OnReadPolicy):
   """
 
   def assign_sub(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUOnReadPolicy, self).assign_sub(var, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -562,7 +554,7 @@ class TPUOnReadPolicy(values.OnReadPolicy):
                                                             **kwargs)
 
   def assign_add(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUOnReadPolicy, self).assign_add(var, *args, **kwargs)
     else:
       return _make_raw_assign_fn(
@@ -570,7 +562,7 @@ class TPUOnReadPolicy(values.OnReadPolicy):
                                                             **kwargs)
 
   def assign(self, var, *args, **kwargs):
-    if enclosing_tpu_context() is None:
+    if tpu_util.enclosing_tpu_context() is None:
       return super(TPUOnReadPolicy, self).assign(var, *args, **kwargs)
     else:
       return _make_raw_assign_fn(gen_resource_variable_ops.assign_variable_op)(