Remove *args from distribute/ variable creators

A variable creator is required to accept only next_creator and **kwargs. Also
accepting *args is confusing when you want to modify the arguments passed to the next creator.

PiperOrigin-RevId: 292739441
Change-Id: I0f56066a9cc927446a7e38a18f0d48aee290fb74
This commit is contained in:
Ran Chen 2020-02-01 16:55:18 -08:00 committed by TensorFlower Gardener
parent adcbdc2bcc
commit 1fa0f82f4c
11 changed files with 55 additions and 53 deletions

View File

@ -1369,7 +1369,8 @@ class StrategyExtendedV2(object):
def _scope(self, strategy):
"""Implementation of tf.distribute.Strategy.scope()."""
def creator_with_resource_vars(*args, **kwargs):
def creator_with_resource_vars(next_creator, **kwargs):
"""Variable creator to use in `_CurrentDistributionContext`."""
_require_strategy_scope_extended(self)
kwargs["use_resource"] = True
@ -1382,7 +1383,7 @@ class StrategyExtendedV2(object):
if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
return self._create_variable(*args, **kwargs)
return self._create_variable(next_creator, **kwargs)
def distributed_getter(getter, *args, **kwargs):
if not self._allow_variable_partition():
@ -1402,7 +1403,7 @@ class StrategyExtendedV2(object):
def _allow_variable_partition(self):
return False
def _create_variable(self, next_creator, *args, **kwargs):
def _create_variable(self, next_creator, **kwargs):
# Note: should support "colocate_with" argument.
raise NotImplementedError("must be implemented in descendants")
@ -1471,11 +1472,12 @@ class StrategyExtendedV2(object):
Returns:
A context manager.
"""
def create_colocated_variable(next_creator, *args, **kwargs):
def create_colocated_variable(next_creator, **kwargs):
_require_strategy_scope_extended(self)
kwargs["use_resource"] = True
kwargs["colocate_with"] = colocate_with_variable
return next_creator(*args, **kwargs)
return next_creator(**kwargs)
_require_strategy_scope_extended(self)
self._validate_colocate_with_variable(colocate_with_variable)
@ -2139,9 +2141,9 @@ class _DefaultDistributionContext(object):
def __init__(self, strategy):
def creator(next_creator, *args, **kwargs):
def creator(next_creator, **kwargs):
_require_strategy_scope_strategy(strategy)
return next_creator(*args, **kwargs)
return next_creator(**kwargs)
self._var_creator_scope = variable_scope.variable_creator_scope(creator)
self._strategy = strategy

View File

@ -77,7 +77,7 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
return fn(*args, **kwargs)
def _create_variable(self, next_creator, *args, **kwargs):
def _create_variable(self, next_creator, **kwargs):
return _get_test_variable(kwargs["name"], kwargs["synchronization"],
kwargs["aggregation"])
@ -432,8 +432,8 @@ class _TestStrategy2(distribute_lib.Strategy):
class _TestExtended2(_TestExtended):
def _create_variable(self, next_creator, *args, **kwargs):
return next_creator(*args, **kwargs)
def _create_variable(self, next_creator, **kwargs):
return next_creator(**kwargs)
class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):

View File

@ -172,8 +172,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
created_variables = []
trainable_variables = []
def appending_creator(next_creator, *args, **kwargs):
v = next_creator(*args, **kwargs)
def appending_creator(next_creator, **kwargs):
v = next_creator(**kwargs)
created_variables.append(v.name)
if "trainable" in kwargs and kwargs["trainable"]:
trainable_variables.append(v.name)

View File

@ -569,18 +569,18 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
return initial_value_fn
def _create_variable(self, next_creator, *args, **kwargs):
def _create_variable(self, next_creator, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
devices = self._devices
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return next_creator(*args, **kwargs)
return next_creator(**kwargs)
else:
devices = colocate_with.devices
def _real_mirrored_creator(*args, **kwargs): # pylint: disable=g-missing-docstring
def _real_mirrored_creator(**kwargs): # pylint: disable=g-missing-docstring
value_list = []
for i, d in enumerate(devices):
with ops.device(d):
@ -600,14 +600,15 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
# Don't record operations (e.g. other variable reads) during
# variable creation.
with tape.stop_recording():
v = next_creator(*args, **kwargs)
v = next_creator(**kwargs)
assert not isinstance(v, values.DistributedVariable)
value_list.append(v)
return value_list
return values.create_mirrored_variable(
self._container_strategy(), _real_mirrored_creator,
values.MirroredVariable, values.SyncOnReadVariable, *args, **kwargs)
return values.create_mirrored_variable(self._container_strategy(),
_real_mirrored_creator,
values.MirroredVariable,
values.SyncOnReadVariable, **kwargs)
def _validate_colocate_with_variable(self, colocate_with_variable):
values.validate_colocate_distributed_variable(colocate_with_variable, self)

View File

@ -293,8 +293,8 @@ class MirroredStrategyVariableCreatorStackTest(
def model_fn():
replica_id_str = str(self.evaluate(_replica_id()))
def thread_creator_fn(next_creator, *args, **kwargs):
return next_creator(*args, **kwargs) + ":thread_" + replica_id_str
def thread_creator_fn(next_creator, **kwargs):
return next_creator(**kwargs) + ":thread_" + replica_id_str
with variable_scope.variable_creator_scope(thread_creator_fn):
# Create a variable in this scope.
@ -304,9 +304,9 @@ class MirroredStrategyVariableCreatorStackTest(
ds_context.get_replica_context().merge_call(lambda _: _)
return v
def main_thread_creator(next_creator, *args, **kwargs):
def main_thread_creator(next_creator, **kwargs):
# We are not using the underlying next_creator for test purposes.
del next_creator, args, kwargs
del next_creator, kwargs
return "main_thread"
with context.graph_mode(), \

View File

@ -75,9 +75,10 @@ def init_var_from_numpy(input_var, numpy_input, session):
def one_host_numpy_dataset(numpy_input, colocate_with, session):
"""Create a dataset on `colocate_with` from `numpy_input`."""
def create_colocated_variable(next_creator, *args, **kwargs):
def create_colocated_variable(next_creator, **kwargs):
kwargs["colocate_with"] = colocate_with
return next_creator(*args, **kwargs)
return next_creator(**kwargs)
numpy_flat = nest.flatten(numpy_input)
with variable_scope.variable_creator_scope(create_colocated_variable):

View File

@ -253,17 +253,17 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
worker_device_pairs = [(self._input_device, [self._device])]
self._input_workers = input_lib.InputWorkers(worker_device_pairs)
def _create_variable(self, next_creator, *args, **kwargs):
def _create_variable(self, next_creator, **kwargs):
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
with ops.device(self._device):
return next_creator(*args, **kwargs)
return next_creator(**kwargs)
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return next_creator(*args, **kwargs)
return next_creator(**kwargs)
else:
with ops.colocate_with(colocate_with):
return next_creator(*args, **kwargs)
return next_creator(**kwargs)
def _validate_colocate_with_variable(self, colocate_with_variable):
values.validate_colocate(colocate_with_variable, self)

View File

@ -388,7 +388,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
# TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS will go through
# this creator, such as "MutableHashTable".
def _create_variable(self, next_creator, *args, **kwargs):
def _create_variable(self, next_creator, **kwargs):
if self._num_replicas_in_sync > 1:
aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
if aggregation not in (
@ -400,7 +400,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
def var_creator(*args, **kwargs):
def var_creator(**kwargs):
"""Create an AggregatingVariable and fix up collections."""
# Record what collections this variable should be added to.
collections = kwargs.pop("collections", None)
@ -409,7 +409,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
kwargs["collections"] = []
# Create and wrap the variable.
v = next_creator(*args, **kwargs)
v = next_creator(**kwargs)
wrapped = values.AggregatingVariable(
self._container_strategy(), v, aggregation)
@ -440,14 +440,14 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
colocate_with = kwargs["colocate_with"]
if isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return var_creator(*args, **kwargs)
return var_creator(**kwargs)
with ops.device(None):
with ops.colocate_with(colocate_with):
return var_creator(*args, **kwargs)
return var_creator(**kwargs)
with ops.colocate_with(None, ignore_existing=True):
with ops.device(self._variable_device):
return var_creator(*args, **kwargs)
return var_creator(**kwargs)
def _call_for_each_replica(self, fn, args, kwargs):
# pylint: disable=protected-access

View File

@ -63,19 +63,19 @@ def make_fn(shared_variable_store, device_id):
variable_scope_access_index = {}
assert isinstance(device_id, int)
def create_new_variable(next_creator, *args, **kwargs):
def create_new_variable(next_creator, **kwargs):
"""Create the variable using `next_creator` and store it."""
canonical_name = _canonicalize_variable_name(kwargs.get("name"))
v = next_creator(*args, **kwargs)
v = next_creator(**kwargs)
if canonical_name not in shared_variable_store:
shared_variable_store[canonical_name] = []
shared_variable_store[canonical_name].append(v)
return v
def reuse_variable(next_creator, *args, **kwargs):
def reuse_variable(next_creator, **kwargs):
"""Re-use existing variable from store with same name (in order)."""
del next_creator, args
del next_creator
name = kwargs.get("name")
canonical_name = _canonicalize_variable_name(name)

View File

@ -508,21 +508,21 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
"""
tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)
def _create_variable(self, next_creator, *args, **kwargs):
def _create_variable(self, next_creator, **kwargs):
"""Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
if kwargs.pop("skip_mirrored_creator", False):
return next_creator(*args, **kwargs)
return next_creator(**kwargs)
colocate_with = kwargs.pop("colocate_with", None)
if colocate_with is None:
devices = self._tpu_devices[:, self._logical_device_stack[-1]]
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
with ops.device(colocate_with.device):
return next_creator(*args, **kwargs)
return next_creator(**kwargs)
else:
devices = colocate_with.devices
def _real_mirrored_creator(*args, **kwargs): # pylint: disable=g-missing-docstring
def _real_mirrored_creator(**kwargs): # pylint: disable=g-missing-docstring
initial_value = None
value_list = []
for i, d in enumerate(devices):
@ -545,18 +545,17 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
kwargs["initial_value"] = initial_value
with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
v = next_creator(*args, **kwargs)
v = next_creator(**kwargs)
assert not isinstance(v, values.TPUMirroredVariable)
value_list.append(v)
return value_list
return values.create_mirrored_variable(
self._container_strategy(),
_real_mirrored_creator,
values.TPUMirroredVariable,
values.TPUSyncOnReadVariable,
*args, **kwargs)
return values.create_mirrored_variable(self._container_strategy(),
_real_mirrored_creator,
values.TPUMirroredVariable,
values.TPUSyncOnReadVariable,
**kwargs)
def _reduce_to(self, reduce_op, value, destinations):
if (isinstance(value, values.DistributedValues) or

View File

@ -728,8 +728,7 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
def create_mirrored_variable( # pylint: disable=missing-docstring
strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls,
*args, **kwargs):
strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
# Figure out what collections this variable should be added to.
# We'll add the MirroredVariable to those collections instead.
var_collections = kwargs.pop("collections", None)
@ -772,7 +771,7 @@ def create_mirrored_variable( # pylint: disable=missing-docstring
# was never recorded on the tape instead of having to do this manually
# here.
with tape.stop_recording():
value_list = real_mirrored_creator(*args, **kwargs)
value_list = real_mirrored_creator(**kwargs)
var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
result = var_cls(strategy, value_list, aggregation)