diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 7ffab1df9a1..f532315e1ef 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -1369,7 +1369,8 @@ class StrategyExtendedV2(object):
 
   def _scope(self, strategy):
     """Implementation of tf.distribute.Strategy.scope()."""
-    def creator_with_resource_vars(*args, **kwargs):
+
+    def creator_with_resource_vars(next_creator, **kwargs):
       """Variable creator to use in `_CurrentDistributionContext`."""
       _require_strategy_scope_extended(self)
       kwargs["use_resource"] = True
@@ -1382,7 +1383,7 @@ class StrategyExtendedV2(object):
       if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
         kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
 
-      return self._create_variable(*args, **kwargs)
+      return self._create_variable(next_creator, **kwargs)
 
     def distributed_getter(getter, *args, **kwargs):
       if not self._allow_variable_partition():
@@ -1402,7 +1403,7 @@ class StrategyExtendedV2(object):
   def _allow_variable_partition(self):
     return False
 
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
     # Note: should support "colocate_with" argument.
     raise NotImplementedError("must be implemented in descendants")
 
@@ -1471,11 +1472,12 @@ class StrategyExtendedV2(object):
     Returns:
       A context manager.
     """
-    def create_colocated_variable(next_creator, *args, **kwargs):
+
+    def create_colocated_variable(next_creator, **kwargs):
       _require_strategy_scope_extended(self)
       kwargs["use_resource"] = True
       kwargs["colocate_with"] = colocate_with_variable
-      return next_creator(*args, **kwargs)
+      return next_creator(**kwargs)
 
     _require_strategy_scope_extended(self)
     self._validate_colocate_with_variable(colocate_with_variable)
@@ -2139,9 +2141,9 @@ class _DefaultDistributionContext(object):
 
   def __init__(self, strategy):
 
-    def creator(next_creator, *args, **kwargs):
+    def creator(next_creator, **kwargs):
       _require_strategy_scope_strategy(strategy)
-      return next_creator(*args, **kwargs)
+      return next_creator(**kwargs)
 
     self._var_creator_scope = variable_scope.variable_creator_scope(creator)
     self._strategy = strategy
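
# Illustrative sketch, not part of this patch: the refactor above narrows the
# variable-creator contract from (next_creator, *args, **kwargs) to
# (next_creator, **kwargs), so every variable attribute travels as a keyword
# argument. A minimal conforming creator under the public
# tf.variable_creator_scope API; the logging behavior is invented here.
import tensorflow as tf

def logging_creator(next_creator, **kwargs):
  # kwargs carries name, initial_value, trainable, synchronization, etc.
  print("creating variable:", kwargs.get("name"))
  return next_creator(**kwargs)

with tf.variable_creator_scope(logging_creator):
  v = tf.Variable(1.0, name="v")  # prints: creating variable: v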
diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py
index 1d171bc5cd5..8c7ad0ae40d 100644
--- a/tensorflow/python/distribute/distribute_lib_test.py
+++ b/tensorflow/python/distribute/distribute_lib_test.py
@@ -77,7 +77,7 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
         replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
       return fn(*args, **kwargs)
 
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
     return _get_test_variable(kwargs["name"], kwargs["synchronization"],
                               kwargs["aggregation"])
 
@@ -432,8 +432,8 @@ class _TestStrategy2(distribute_lib.Strategy):
 
 class _TestExtended2(_TestExtended):
 
-  def _create_variable(self, next_creator, *args, **kwargs):
-    return next_creator(*args, **kwargs)
+  def _create_variable(self, next_creator, **kwargs):
+    return next_creator(**kwargs)
 
 
 class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):
diff --git a/tensorflow/python/distribute/minimize_loss_test.py b/tensorflow/python/distribute/minimize_loss_test.py
index 743614b096f..fb9aa61aa3f 100644
--- a/tensorflow/python/distribute/minimize_loss_test.py
+++ b/tensorflow/python/distribute/minimize_loss_test.py
@@ -172,8 +172,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     created_variables = []
     trainable_variables = []
 
-    def appending_creator(next_creator, *args, **kwargs):
-      v = next_creator(*args, **kwargs)
+    def appending_creator(next_creator, **kwargs):
+      v = next_creator(**kwargs)
       created_variables.append(v.name)
       if "trainable" in kwargs and kwargs["trainable"]:
         trainable_variables.append(v.name)
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 9e26421b824..22c4ba62afa 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -569,18 +569,18 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
 
       return initial_value_fn
 
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
     """Create a mirrored variable. See `DistributionStrategy.scope`."""
     colocate_with = kwargs.pop("colocate_with", None)
     if colocate_with is None:
       devices = self._devices
     elif isinstance(colocate_with, numpy_dataset.SingleDevice):
       with ops.device(colocate_with.device):
-        return next_creator(*args, **kwargs)
+        return next_creator(**kwargs)
     else:
       devices = colocate_with.devices
 
-    def _real_mirrored_creator(*args, **kwargs):  # pylint: disable=g-missing-docstring
+    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
       value_list = []
       for i, d in enumerate(devices):
         with ops.device(d):
@@ -600,14 +600,15 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
             # Don't record operations (e.g. other variable reads) during
             # variable creation.
             with tape.stop_recording():
-              v = next_creator(*args, **kwargs)
+              v = next_creator(**kwargs)
           assert not isinstance(v, values.DistributedVariable)
           value_list.append(v)
       return value_list
 
-    return values.create_mirrored_variable(
-        self._container_strategy(), _real_mirrored_creator,
-        values.MirroredVariable, values.SyncOnReadVariable, *args, **kwargs)
+    return values.create_mirrored_variable(self._container_strategy(),
+                                           _real_mirrored_creator,
+                                           values.MirroredVariable,
+                                           values.SyncOnReadVariable, **kwargs)
 
   def _validate_colocate_with_variable(self, colocate_with_variable):
     values.validate_colocate_distributed_variable(colocate_with_variable, self)
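
# Illustrative sketch, not part of this patch: _create_variable above is what
# tf.Variable reaches under a strategy scope; _real_mirrored_creator builds one
# component variable per device and create_mirrored_variable wraps the list.
# Observable behavior through the public API (device list is an assumption):
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  v = tf.Variable(0.0, name="counter")
# v is a MirroredVariable with one component per device in the strategy.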
diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py
index 29bf69d64ec..9446d898126 100644
--- a/tensorflow/python/distribute/mirrored_strategy_test.py
+++ b/tensorflow/python/distribute/mirrored_strategy_test.py
@@ -293,8 +293,8 @@ class MirroredStrategyVariableCreatorStackTest(
     def model_fn():
       replica_id_str = str(self.evaluate(_replica_id()))
 
-      def thread_creator_fn(next_creator, *args, **kwargs):
-        return next_creator(*args, **kwargs) + ":thread_" + replica_id_str
+      def thread_creator_fn(next_creator, **kwargs):
+        return next_creator(**kwargs) + ":thread_" + replica_id_str
 
       with variable_scope.variable_creator_scope(thread_creator_fn):
         # Create a variable in this scope.
@@ -304,9 +304,9 @@ class MirroredStrategyVariableCreatorStackTest(
         ds_context.get_replica_context().merge_call(lambda _: _)
       return v
 
-    def main_thread_creator(next_creator, *args, **kwargs):
+    def main_thread_creator(next_creator, **kwargs):
       # We are not using the underlying next_creator for test purposes.
-      del next_creator, args, kwargs
+      del next_creator, kwargs
       return "main_thread"
 
     with context.graph_mode(), \
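
# Illustrative sketch, not part of this patch: the test above relies on
# creator chaining, where the innermost scope's creator runs first and
# delegates outward through next_creator; that is why its result reads
# "main_thread:thread_<id>". A standalone, eager-mode sketch of that order,
# using a short-circuiting outer creator just as the test does:
import tensorflow as tf

def base(next_creator, **kwargs):
  del next_creator, kwargs  # short-circuit: never create a real variable
  return "base"

def suffixing(next_creator, **kwargs):
  return next_creator(**kwargs) + ":suffixed"

with tf.variable_creator_scope(base):
  with tf.variable_creator_scope(suffixing):
    assert tf.Variable(1.0) == "base:suffixed"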
diff --git a/tensorflow/python/distribute/numpy_dataset.py b/tensorflow/python/distribute/numpy_dataset.py
index 5881e4cd59e..0d8df03f88c 100644
--- a/tensorflow/python/distribute/numpy_dataset.py
+++ b/tensorflow/python/distribute/numpy_dataset.py
@@ -75,9 +75,10 @@ def init_var_from_numpy(input_var, numpy_input, session):
 
 def one_host_numpy_dataset(numpy_input, colocate_with, session):
   """Create a dataset on `colocate_with` from `numpy_input`."""
-  def create_colocated_variable(next_creator, *args, **kwargs):
+
+  def create_colocated_variable(next_creator, **kwargs):
     kwargs["colocate_with"] = colocate_with
-    return next_creator(*args, **kwargs)
+    return next_creator(**kwargs)
 
   numpy_flat = nest.flatten(numpy_input)
   with variable_scope.variable_creator_scope(create_colocated_variable):
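
# Illustrative sketch, not part of this patch: create_colocated_variable above
# shows the idiom of injecting a keyword before delegating. The same pattern
# works for any variable attribute; for instance, a creator (invented here)
# that forces everything created in its scope to be non-trainable:
def non_trainable_creator(next_creator, **kwargs):
  kwargs["trainable"] = False
  return next_creator(**kwargs)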
diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py
index 144ce6a8fce..2e52cfb457a 100644
--- a/tensorflow/python/distribute/one_device_strategy.py
+++ b/tensorflow/python/distribute/one_device_strategy.py
@@ -253,17 +253,17 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
     worker_device_pairs = [(self._input_device, [self._device])]
     self._input_workers = input_lib.InputWorkers(worker_device_pairs)
 
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
     colocate_with = kwargs.pop("colocate_with", None)
     if colocate_with is None:
       with ops.device(self._device):
-        return next_creator(*args, **kwargs)
+        return next_creator(**kwargs)
     elif isinstance(colocate_with, numpy_dataset.SingleDevice):
       with ops.device(colocate_with.device):
-        return next_creator(*args, **kwargs)
+        return next_creator(**kwargs)
     else:
       with ops.colocate_with(colocate_with):
-        return next_creator(*args, **kwargs)
+        return next_creator(**kwargs)
 
   def _validate_colocate_with_variable(self, colocate_with_variable):
     values.validate_colocate(colocate_with_variable, self)
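
# Illustrative sketch, not part of this patch: OneDeviceStrategy's
# _create_variable simply pins placement. The observable behavior through the
# public API (the exact device string depends on the host):
import tensorflow as tf

strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with strategy.scope():
  v = tf.Variable(1.0)
assert "CPU:0" in v.device  # the variable lands on the strategy's device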
diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py
index 900b6a5b453..72d25db474b 100644
--- a/tensorflow/python/distribute/parameter_server_strategy.py
+++ b/tensorflow/python/distribute/parameter_server_strategy.py
@@ -388,7 +388,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
 
   # TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS will go through
   # this creator, such as "MutableHashTable".
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
     if self._num_replicas_in_sync > 1:
       aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
       if aggregation not in (
@@ -400,7 +400,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
         raise ValueError("Invalid variable aggregation mode: " + aggregation +
                          " for variable: " + kwargs["name"])
 
-      def var_creator(*args, **kwargs):
+      def var_creator(**kwargs):
         """Create an AggregatingVariable and fix up collections."""
         # Record what collections this variable should be added to.
         collections = kwargs.pop("collections", None)
@@ -409,7 +409,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
         kwargs["collections"] = []
 
         # Create and wrap the variable.
-        v = next_creator(*args, **kwargs)
+        v = next_creator(**kwargs)
         wrapped = values.AggregatingVariable(
             self._container_strategy(), v, aggregation)
 
@@ -440,14 +440,14 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
       colocate_with = kwargs["colocate_with"]
       if isinstance(colocate_with, numpy_dataset.SingleDevice):
         with ops.device(colocate_with.device):
-          return var_creator(*args, **kwargs)
+          return var_creator(**kwargs)
       with ops.device(None):
         with ops.colocate_with(colocate_with):
-          return var_creator(*args, **kwargs)
+          return var_creator(**kwargs)
 
     with ops.colocate_with(None, ignore_existing=True):
       with ops.device(self._variable_device):
-        return var_creator(*args, **kwargs)
+        return var_creator(**kwargs)
 
   def _call_for_each_replica(self, fn, args, kwargs):
     # pylint: disable=protected-access
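
# Illustrative sketch, not part of this patch: var_creator above uses a
# create-outside-collections-then-register pattern so the wrapper, not the
# wrapped variable, lands in the graph collections. A generic version of the
# same pattern, assuming TF 1.x-style collections reach the creator via
# kwargs; the wrapping step is elided since AggregatingVariable is internal.
import tensorflow as tf

def wrapping_creator(next_creator, **kwargs):
  collections = kwargs.pop("collections", None)
  if collections is None:
    collections = [tf.compat.v1.GraphKeys.GLOBAL_VARIABLES]
  kwargs["collections"] = []  # keep the raw variable out of collections
  v = next_creator(**kwargs)
  wrapped = v  # a real strategy wraps v here, e.g. in AggregatingVariable
  tf.compat.v1.add_to_collections(collections, wrapped)
  return wrapped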
diff --git a/tensorflow/python/distribute/shared_variable_creator.py b/tensorflow/python/distribute/shared_variable_creator.py
index a7083e279f2..11ed271bf4c 100644
--- a/tensorflow/python/distribute/shared_variable_creator.py
+++ b/tensorflow/python/distribute/shared_variable_creator.py
@@ -63,19 +63,19 @@ def make_fn(shared_variable_store, device_id):
   variable_scope_access_index = {}
   assert isinstance(device_id, int)
 
-  def create_new_variable(next_creator, *args, **kwargs):
+  def create_new_variable(next_creator, **kwargs):
     """Create the variable using `next_creator` and store it."""
     canonical_name = _canonicalize_variable_name(kwargs.get("name"))
-    v = next_creator(*args, **kwargs)
+    v = next_creator(**kwargs)
 
     if canonical_name not in shared_variable_store:
       shared_variable_store[canonical_name] = []
     shared_variable_store[canonical_name].append(v)
     return v
 
-  def reuse_variable(next_creator, *args, **kwargs):
+  def reuse_variable(next_creator, **kwargs):
     """Re-use existing variable from store with same name (in order)."""
-    del next_creator, args
+    del next_creator
     name = kwargs.get("name")
     canonical_name = _canonicalize_variable_name(name)
 
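# Illustrative sketch, not part of this patch: make_fn hands device 0 the
# storing creator and every later device the reusing one, so replicas beyond
# the first share the first replica's variables by canonical name and access
# order. A hedged wiring example (shared_variable_creator is internal, so the
# import path may drift across versions):
import tensorflow as tf
from tensorflow.python.distribute import shared_variable_creator

store = {}
with tf.variable_creator_scope(shared_variable_creator.make_fn(store, 0)):
  v0 = tf.Variable(1.0, name="w")  # created and recorded in the store
with tf.variable_creator_scope(shared_variable_creator.make_fn(store, 1)):
  v1 = tf.Variable(1.0, name="w")  # reused from the store, not recreated
assert v0 is v1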
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index e7335e23c9a..0a127ca5167 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -508,21 +508,21 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     """
     tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)
 
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
     """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
     if kwargs.pop("skip_mirrored_creator", False):
-      return next_creator(*args, **kwargs)
+      return next_creator(**kwargs)
 
     colocate_with = kwargs.pop("colocate_with", None)
     if colocate_with is None:
       devices = self._tpu_devices[:, self._logical_device_stack[-1]]
     elif isinstance(colocate_with, numpy_dataset.SingleDevice):
       with ops.device(colocate_with.device):
-        return next_creator(*args, **kwargs)
+        return next_creator(**kwargs)
     else:
       devices = colocate_with.devices
 
-    def _real_mirrored_creator(*args, **kwargs):  # pylint: disable=g-missing-docstring
+    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
       initial_value = None
       value_list = []
       for i, d in enumerate(devices):
@@ -545,18 +545,17 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
           kwargs["initial_value"] = initial_value
 
           with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
-            v = next_creator(*args, **kwargs)
+            v = next_creator(**kwargs)
 
           assert not isinstance(v, values.TPUMirroredVariable)
           value_list.append(v)
       return value_list
 
-    return values.create_mirrored_variable(
-        self._container_strategy(),
-        _real_mirrored_creator,
-        values.TPUMirroredVariable,
-        values.TPUSyncOnReadVariable,
-        *args, **kwargs)
+    return values.create_mirrored_variable(self._container_strategy(),
+                                           _real_mirrored_creator,
+                                           values.TPUMirroredVariable,
+                                           values.TPUSyncOnReadVariable,
+                                           **kwargs)
 
   def _reduce_to(self, reduce_op, value, destinations):
     if (isinstance(value, values.DistributedValues) or
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 7c7a3b5505f..0126df3ae51 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -728,8 +728,7 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
 
 
 def create_mirrored_variable(  # pylint: disable=missing-docstring
-    strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls,
-    *args, **kwargs):
+    strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
   # Figure out what collections this variable should be added to.
   # We'll add the MirroredVariable to those collections instead.
   var_collections = kwargs.pop("collections", None)
@@ -772,7 +771,7 @@ def create_mirrored_variable(  # pylint: disable=missing-docstring
   # was never recorded on the tape instead of having to do this manually
   # here.
   with tape.stop_recording():
-    value_list = real_mirrored_creator(*args, **kwargs)
+    value_list = real_mirrored_creator(**kwargs)
     var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
     result = var_cls(strategy, value_list, aggregation)
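
# Illustrative sketch, not part of this patch: with the *args plumbing gone,
# create_mirrored_variable reads everything, including synchronization and
# aggregation, from kwargs alone. Those arrive as ordinary tf.Variable
# keywords, e.g. under a strategy scope (device list is an assumption):
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["/cpu:0"])
with strategy.scope():
  v = tf.Variable(
      0.0,
      synchronization=tf.VariableSynchronization.ON_READ,
      aggregation=tf.VariableAggregation.SUM)  # becomes a SyncOnReadVariable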