diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 7ffab1df9a1..f532315e1ef 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -1369,7 +1369,8 @@ class StrategyExtendedV2(object):
 
   def _scope(self, strategy):
     """Implementation of tf.distribute.Strategy.scope()."""
-    def creator_with_resource_vars(*args, **kwargs):
+
+    def creator_with_resource_vars(next_creator, **kwargs):
       """Variable creator to use in `_CurrentDistributionContext`."""
       _require_strategy_scope_extended(self)
       kwargs["use_resource"] = True
@@ -1382,7 +1383,7 @@ class StrategyExtendedV2(object):
       if isinstance(kwargs["initial_value"],
                     trackable.CheckpointInitialValue):
         kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
-      return self._create_variable(*args, **kwargs)
+      return self._create_variable(next_creator, **kwargs)
 
     def distributed_getter(getter, *args, **kwargs):
       if not self._allow_variable_partition():
@@ -1402,7 +1403,7 @@ class StrategyExtendedV2(object):
   def _allow_variable_partition(self):
     return False
 
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
     # Note: should support "colocate_with" argument.
     raise NotImplementedError("must be implemented in descendants")
 
@@ -1471,11 +1472,12 @@ class StrategyExtendedV2(object):
     Returns:
       A context manager.
     """
-    def create_colocated_variable(next_creator, *args, **kwargs):
+
+    def create_colocated_variable(next_creator, **kwargs):
       _require_strategy_scope_extended(self)
       kwargs["use_resource"] = True
       kwargs["colocate_with"] = colocate_with_variable
-      return next_creator(*args, **kwargs)
+      return next_creator(**kwargs)
 
     _require_strategy_scope_extended(self)
     self._validate_colocate_with_variable(colocate_with_variable)
@@ -2139,9 +2141,9 @@ class _DefaultDistributionContext(object):
 
   def __init__(self, strategy):
 
-    def creator(next_creator, *args, **kwargs):
+    def creator(next_creator, **kwargs):
       _require_strategy_scope_strategy(strategy)
-      return next_creator(*args, **kwargs)
+      return next_creator(**kwargs)
 
     self._var_creator_scope = variable_scope.variable_creator_scope(creator)
     self._strategy = strategy
diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py
index 1d171bc5cd5..8c7ad0ae40d 100644
--- a/tensorflow/python/distribute/distribute_lib_test.py
+++ b/tensorflow/python/distribute/distribute_lib_test.py
@@ -77,7 +77,7 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
         replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
       return fn(*args, **kwargs)
 
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
     return _get_test_variable(kwargs["name"], kwargs["synchronization"],
                               kwargs["aggregation"])
 
@@ -432,8 +432,8 @@ class _TestStrategy2(distribute_lib.Strategy):
 
 class _TestExtended2(_TestExtended):
 
-  def _create_variable(self, next_creator, *args, **kwargs):
-    return next_creator(*args, **kwargs)
+  def _create_variable(self, next_creator, **kwargs):
+    return next_creator(**kwargs)
 
 
 class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):
diff --git a/tensorflow/python/distribute/minimize_loss_test.py b/tensorflow/python/distribute/minimize_loss_test.py
index 743614b096f..fb9aa61aa3f 100644
--- a/tensorflow/python/distribute/minimize_loss_test.py
+++ b/tensorflow/python/distribute/minimize_loss_test.py
@@ -172,8 +172,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     created_variables = []
     trainable_variables = []
 
-    def appending_creator(next_creator, *args, **kwargs):
-      v = next_creator(*args, **kwargs)
+    def appending_creator(next_creator, **kwargs):
+      v = next_creator(**kwargs)
       created_variables.append(v.name)
       if "trainable" in kwargs and kwargs["trainable"]:
         trainable_variables.append(v.name)
diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py
index 9e26421b824..22c4ba62afa 100644
--- a/tensorflow/python/distribute/mirrored_strategy.py
+++ b/tensorflow/python/distribute/mirrored_strategy.py
@@ -569,18 +569,18 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
 
     return initial_value_fn
 
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
    """Create a mirrored variable. See `DistributionStrategy.scope`."""
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._devices
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
-        return next_creator(*args, **kwargs)
+        return next_creator(**kwargs)
    else:
      devices = colocate_with.devices
 
-    def _real_mirrored_creator(*args, **kwargs):  # pylint: disable=g-missing-docstring
+    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      value_list = []
      for i, d in enumerate(devices):
        with ops.device(d):
@@ -600,14 +600,15 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
          # Don't record operations (e.g. other variable reads) during
          # variable creation.
          with tape.stop_recording():
-            v = next_creator(*args, **kwargs)
+            v = next_creator(**kwargs)
          assert not isinstance(v, values.DistributedVariable)
          value_list.append(v)
      return value_list
 
-    return values.create_mirrored_variable(
-        self._container_strategy(), _real_mirrored_creator,
-        values.MirroredVariable, values.SyncOnReadVariable, *args, **kwargs)
+    return values.create_mirrored_variable(self._container_strategy(),
+                                           _real_mirrored_creator,
+                                           values.MirroredVariable,
+                                           values.SyncOnReadVariable, **kwargs)
 
  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate_distributed_variable(colocate_with_variable, self)
diff --git a/tensorflow/python/distribute/mirrored_strategy_test.py b/tensorflow/python/distribute/mirrored_strategy_test.py
index 29bf69d64ec..9446d898126 100644
--- a/tensorflow/python/distribute/mirrored_strategy_test.py
+++ b/tensorflow/python/distribute/mirrored_strategy_test.py
@@ -293,8 +293,8 @@ class MirroredStrategyVariableCreatorStackTest(
    def model_fn():
      replica_id_str = str(self.evaluate(_replica_id()))
 
-      def thread_creator_fn(next_creator, *args, **kwargs):
-        return next_creator(*args, **kwargs) + ":thread_" + replica_id_str
+      def thread_creator_fn(next_creator, **kwargs):
+        return next_creator(**kwargs) + ":thread_" + replica_id_str
 
      with variable_scope.variable_creator_scope(thread_creator_fn):
        # Create a variable in this scope.
@@ -304,9 +304,9 @@ class MirroredStrategyVariableCreatorStackTest(
      ds_context.get_replica_context().merge_call(lambda _: _)
      return v
 
-    def main_thread_creator(next_creator, *args, **kwargs):
+    def main_thread_creator(next_creator, **kwargs):
      # We are not using the underlying next_creator for test purposes.
-      del next_creator, args, kwargs
+      del next_creator, kwargs
      return "main_thread"
 
    with context.graph_mode(), \
diff --git a/tensorflow/python/distribute/numpy_dataset.py b/tensorflow/python/distribute/numpy_dataset.py
index 5881e4cd59e..0d8df03f88c 100644
--- a/tensorflow/python/distribute/numpy_dataset.py
+++ b/tensorflow/python/distribute/numpy_dataset.py
@@ -75,9 +75,10 @@ def init_var_from_numpy(input_var, numpy_input, session):
 
 def one_host_numpy_dataset(numpy_input, colocate_with, session):
  """Create a dataset on `colocate_with` from `numpy_input`."""
-  def create_colocated_variable(next_creator, *args, **kwargs):
+
+  def create_colocated_variable(next_creator, **kwargs):
    kwargs["colocate_with"] = colocate_with
-    return next_creator(*args, **kwargs)
+    return next_creator(**kwargs)
 
  numpy_flat = nest.flatten(numpy_input)
  with variable_scope.variable_creator_scope(create_colocated_variable):
diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py
index 144ce6a8fce..2e52cfb457a 100644
--- a/tensorflow/python/distribute/one_device_strategy.py
+++ b/tensorflow/python/distribute/one_device_strategy.py
@@ -253,17 +253,17 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
    worker_device_pairs = [(self._input_device, [self._device])]
    self._input_workers = input_lib.InputWorkers(worker_device_pairs)
 
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      with ops.device(self._device):
-        return next_creator(*args, **kwargs)
+        return next_creator(**kwargs)
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
-        return next_creator(*args, **kwargs)
+        return next_creator(**kwargs)
    else:
      with ops.colocate_with(colocate_with):
-        return next_creator(*args, **kwargs)
+        return next_creator(**kwargs)
 
  def _validate_colocate_with_variable(self, colocate_with_variable):
    values.validate_colocate(colocate_with_variable, self)
diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py
index 900b6a5b453..72d25db474b 100644
--- a/tensorflow/python/distribute/parameter_server_strategy.py
+++ b/tensorflow/python/distribute/parameter_server_strategy.py
@@ -388,7 +388,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
  # TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS will go through
  # this creator, such as "MutableHashTable".
 
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
    if self._num_replicas_in_sync > 1:
      aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
      if aggregation not in (
@@ -400,7 +400,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
        raise ValueError("Invalid variable aggregation mode: " + aggregation +
                         " for variable: " + kwargs["name"])
 
-      def var_creator(*args, **kwargs):
+      def var_creator(**kwargs):
        """Create an AggregatingVariable and fix up collections."""
        # Record what collections this variable should be added to.
        collections = kwargs.pop("collections", None)
@@ -409,7 +409,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
        kwargs["collections"] = []
 
        # Create and wrap the variable.
-        v = next_creator(*args, **kwargs)
+        v = next_creator(**kwargs)
        wrapped = values.AggregatingVariable(
            self._container_strategy(), v, aggregation)
 
@@ -440,14 +440,14 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
      colocate_with = kwargs["colocate_with"]
      if isinstance(colocate_with, numpy_dataset.SingleDevice):
        with ops.device(colocate_with.device):
-          return var_creator(*args, **kwargs)
+          return var_creator(**kwargs)
      with ops.device(None):
        with ops.colocate_with(colocate_with):
-          return var_creator(*args, **kwargs)
+          return var_creator(**kwargs)
 
    with ops.colocate_with(None, ignore_existing=True):
      with ops.device(self._variable_device):
-        return var_creator(*args, **kwargs)
+        return var_creator(**kwargs)
 
  def _call_for_each_replica(self, fn, args, kwargs):
    # pylint: disable=protected-access
diff --git a/tensorflow/python/distribute/shared_variable_creator.py b/tensorflow/python/distribute/shared_variable_creator.py
index a7083e279f2..11ed271bf4c 100644
--- a/tensorflow/python/distribute/shared_variable_creator.py
+++ b/tensorflow/python/distribute/shared_variable_creator.py
@@ -63,19 +63,19 @@ def make_fn(shared_variable_store, device_id):
  variable_scope_access_index = {}
  assert isinstance(device_id, int)
 
-  def create_new_variable(next_creator, *args, **kwargs):
+  def create_new_variable(next_creator, **kwargs):
    """Create the variable using `next_creator` and store it."""
    canonical_name = _canonicalize_variable_name(kwargs.get("name"))
-    v = next_creator(*args, **kwargs)
+    v = next_creator(**kwargs)
 
    if canonical_name not in shared_variable_store:
      shared_variable_store[canonical_name] = []
    shared_variable_store[canonical_name].append(v)
    return v
 
-  def reuse_variable(next_creator, *args, **kwargs):
+  def reuse_variable(next_creator, **kwargs):
    """Re-use existing variable from store with same name (in order)."""
-    del next_creator, args
+    del next_creator
    name = kwargs.get("name")
    canonical_name = _canonicalize_variable_name(name)
 
diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py
index e7335e23c9a..0a127ca5167 100644
--- a/tensorflow/python/distribute/tpu_strategy.py
+++ b/tensorflow/python/distribute/tpu_strategy.py
@@ -508,21 +508,21 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
    """
    tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)
 
-  def _create_variable(self, next_creator, *args, **kwargs):
+  def _create_variable(self, next_creator, **kwargs):
    """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
    if kwargs.pop("skip_mirrored_creator", False):
-      return next_creator(*args, **kwargs)
+      return next_creator(**kwargs)
 
    colocate_with = kwargs.pop("colocate_with", None)
    if colocate_with is None:
      devices = self._tpu_devices[:, self._logical_device_stack[-1]]
    elif isinstance(colocate_with, numpy_dataset.SingleDevice):
      with ops.device(colocate_with.device):
-        return next_creator(*args, **kwargs)
+        return next_creator(**kwargs)
    else:
      devices = colocate_with.devices
 
-    def _real_mirrored_creator(*args, **kwargs):  # pylint: disable=g-missing-docstring
+    def _real_mirrored_creator(**kwargs):  # pylint: disable=g-missing-docstring
      initial_value = None
      value_list = []
      for i, d in enumerate(devices):
@@ -545,18 +545,17 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
            kwargs["initial_value"] = initial_value
 
          with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
-            v = next_creator(*args, **kwargs)
+            v = next_creator(**kwargs)
 
          assert not isinstance(v, values.TPUMirroredVariable)
          value_list.append(v)
      return value_list
 
-    return values.create_mirrored_variable(
-        self._container_strategy(),
-        _real_mirrored_creator,
-        values.TPUMirroredVariable,
-        values.TPUSyncOnReadVariable,
-        *args, **kwargs)
+    return values.create_mirrored_variable(self._container_strategy(),
+                                           _real_mirrored_creator,
+                                           values.TPUMirroredVariable,
+                                           values.TPUSyncOnReadVariable,
+                                           **kwargs)
 
  def _reduce_to(self, reduce_op, value, destinations):
    if (isinstance(value, values.DistributedValues) or
diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py
index 7c7a3b5505f..0126df3ae51 100644
--- a/tensorflow/python/distribute/values.py
+++ b/tensorflow/python/distribute/values.py
@@ -728,8 +728,7 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
 
 
 def create_mirrored_variable(  # pylint: disable=missing-docstring
-    strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls,
-    *args, **kwargs):
+    strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
  # Figure out what collections this variable should be added to.
  # We'll add the MirroredVariable to those collections instead.
  var_collections = kwargs.pop("collections", None)
@@ -772,7 +771,7 @@ def create_mirrored_variable(  # pylint: disable=missing-docstring
  # was never recorded on the tape instead of having to do this manually
  # here.
  with tape.stop_recording():
-    value_list = real_mirrored_creator(*args, **kwargs)
+    value_list = real_mirrored_creator(**kwargs)
    var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
    result = var_cls(strategy, value_list, aggregation)
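For reference, here is a minimal sketch (not part of this diff) of the calling convention the change standardizes: variable creators take `next_creator` plus keyword arguments only, and forward them with `next_creator(**kwargs)` instead of threading a positional `*args` tail through every layer. The creator name `logging_creator` is hypothetical; the snippet assumes TF 2.x, where `tf.variable_creator_scope` installs a custom creator:

```python
import tensorflow as tf


def logging_creator(next_creator, **kwargs):
  # Hypothetical creator: every variable argument (name, initial_value,
  # trainable, ...) arrives in kwargs, so it can be inspected or rewritten
  # before being forwarded to the next creator in the stack.
  print("creating variable:", kwargs.get("name"))
  return next_creator(**kwargs)


with tf.variable_creator_scope(logging_creator):
  v = tf.Variable(1.0, name="x")  # prints: creating variable: x
```

With the positional path gone, overrides like `_create_variable(self, next_creator, **kwargs)` can rely on every argument being addressable by name in `kwargs` (e.g. `kwargs["name"]`, `kwargs["initial_value"]`), which is what the updated call sites above do.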