Remove *args from distribute/ variable creators
A variable creator is required to accept only `next_creator` and **kwargs. Also accepting *args makes it confusing to modify the arguments that are passed on to the next creator. PiperOrigin-RevId: 292739441 Change-Id: I0f56066a9cc927446a7e38a18f0d48aee290fb74
This commit is contained in:
parent
adcbdc2bcc
commit
1fa0f82f4c
@ -1369,7 +1369,8 @@ class StrategyExtendedV2(object):
|
||||
|
||||
def _scope(self, strategy):
|
||||
"""Implementation of tf.distribute.Strategy.scope()."""
|
||||
def creator_with_resource_vars(*args, **kwargs):
|
||||
|
||||
def creator_with_resource_vars(next_creator, **kwargs):
|
||||
"""Variable creator to use in `_CurrentDistributionContext`."""
|
||||
_require_strategy_scope_extended(self)
|
||||
kwargs["use_resource"] = True
|
||||
@ -1382,7 +1383,7 @@ class StrategyExtendedV2(object):
|
||||
if isinstance(kwargs["initial_value"], trackable.CheckpointInitialValue):
|
||||
kwargs["initial_value"] = kwargs["initial_value"].wrapped_value
|
||||
|
||||
return self._create_variable(*args, **kwargs)
|
||||
return self._create_variable(next_creator, **kwargs)
|
||||
|
||||
def distributed_getter(getter, *args, **kwargs):
|
||||
if not self._allow_variable_partition():
|
||||
@ -1402,7 +1403,7 @@ class StrategyExtendedV2(object):
|
||||
def _allow_variable_partition(self):
|
||||
return False
|
||||
|
||||
def _create_variable(self, next_creator, *args, **kwargs):
|
||||
def _create_variable(self, next_creator, **kwargs):
|
||||
# Note: should support "colocate_with" argument.
|
||||
raise NotImplementedError("must be implemented in descendants")
|
||||
|
||||
@ -1471,11 +1472,12 @@ class StrategyExtendedV2(object):
|
||||
Returns:
|
||||
A context manager.
|
||||
"""
|
||||
def create_colocated_variable(next_creator, *args, **kwargs):
|
||||
|
||||
def create_colocated_variable(next_creator, **kwargs):
|
||||
_require_strategy_scope_extended(self)
|
||||
kwargs["use_resource"] = True
|
||||
kwargs["colocate_with"] = colocate_with_variable
|
||||
return next_creator(*args, **kwargs)
|
||||
return next_creator(**kwargs)
|
||||
|
||||
_require_strategy_scope_extended(self)
|
||||
self._validate_colocate_with_variable(colocate_with_variable)
|
||||
@ -2139,9 +2141,9 @@ class _DefaultDistributionContext(object):
|
||||
|
||||
def __init__(self, strategy):
|
||||
|
||||
def creator(next_creator, *args, **kwargs):
|
||||
def creator(next_creator, **kwargs):
|
||||
_require_strategy_scope_strategy(strategy)
|
||||
return next_creator(*args, **kwargs)
|
||||
return next_creator(**kwargs)
|
||||
|
||||
self._var_creator_scope = variable_scope.variable_creator_scope(creator)
|
||||
self._strategy = strategy
|
||||
|
@ -77,7 +77,7 @@ class _TestExtended(distribute_lib.StrategyExtendedV1):
|
||||
replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
def _create_variable(self, next_creator, *args, **kwargs):
|
||||
def _create_variable(self, next_creator, **kwargs):
|
||||
return _get_test_variable(kwargs["name"], kwargs["synchronization"],
|
||||
kwargs["aggregation"])
|
||||
|
||||
@ -432,8 +432,8 @@ class _TestStrategy2(distribute_lib.Strategy):
|
||||
|
||||
class _TestExtended2(_TestExtended):
|
||||
|
||||
def _create_variable(self, next_creator, *args, **kwargs):
|
||||
return next_creator(*args, **kwargs)
|
||||
def _create_variable(self, next_creator, **kwargs):
|
||||
return next_creator(**kwargs)
|
||||
|
||||
|
||||
class DefaultDistributionStrategyTest(test.TestCase, parameterized.TestCase):
|
||||
|
@ -172,8 +172,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
created_variables = []
|
||||
trainable_variables = []
|
||||
|
||||
def appending_creator(next_creator, *args, **kwargs):
|
||||
v = next_creator(*args, **kwargs)
|
||||
def appending_creator(next_creator, **kwargs):
|
||||
v = next_creator(**kwargs)
|
||||
created_variables.append(v.name)
|
||||
if "trainable" in kwargs and kwargs["trainable"]:
|
||||
trainable_variables.append(v.name)
|
||||
|
@ -569,18 +569,18 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
|
||||
return initial_value_fn
|
||||
|
||||
def _create_variable(self, next_creator, *args, **kwargs):
|
||||
def _create_variable(self, next_creator, **kwargs):
|
||||
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
|
||||
colocate_with = kwargs.pop("colocate_with", None)
|
||||
if colocate_with is None:
|
||||
devices = self._devices
|
||||
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
|
||||
with ops.device(colocate_with.device):
|
||||
return next_creator(*args, **kwargs)
|
||||
return next_creator(**kwargs)
|
||||
else:
|
||||
devices = colocate_with.devices
|
||||
|
||||
def _real_mirrored_creator(*args, **kwargs): # pylint: disable=g-missing-docstring
|
||||
def _real_mirrored_creator(**kwargs): # pylint: disable=g-missing-docstring
|
||||
value_list = []
|
||||
for i, d in enumerate(devices):
|
||||
with ops.device(d):
|
||||
@ -600,14 +600,15 @@ class MirroredExtended(distribute_lib.StrategyExtendedV1):
|
||||
# Don't record operations (e.g. other variable reads) during
|
||||
# variable creation.
|
||||
with tape.stop_recording():
|
||||
v = next_creator(*args, **kwargs)
|
||||
v = next_creator(**kwargs)
|
||||
assert not isinstance(v, values.DistributedVariable)
|
||||
value_list.append(v)
|
||||
return value_list
|
||||
|
||||
return values.create_mirrored_variable(
|
||||
self._container_strategy(), _real_mirrored_creator,
|
||||
values.MirroredVariable, values.SyncOnReadVariable, *args, **kwargs)
|
||||
return values.create_mirrored_variable(self._container_strategy(),
|
||||
_real_mirrored_creator,
|
||||
values.MirroredVariable,
|
||||
values.SyncOnReadVariable, **kwargs)
|
||||
|
||||
def _validate_colocate_with_variable(self, colocate_with_variable):
|
||||
values.validate_colocate_distributed_variable(colocate_with_variable, self)
|
||||
|
@ -293,8 +293,8 @@ class MirroredStrategyVariableCreatorStackTest(
|
||||
def model_fn():
|
||||
replica_id_str = str(self.evaluate(_replica_id()))
|
||||
|
||||
def thread_creator_fn(next_creator, *args, **kwargs):
|
||||
return next_creator(*args, **kwargs) + ":thread_" + replica_id_str
|
||||
def thread_creator_fn(next_creator, **kwargs):
|
||||
return next_creator(**kwargs) + ":thread_" + replica_id_str
|
||||
|
||||
with variable_scope.variable_creator_scope(thread_creator_fn):
|
||||
# Create a variable in this scope.
|
||||
@ -304,9 +304,9 @@ class MirroredStrategyVariableCreatorStackTest(
|
||||
ds_context.get_replica_context().merge_call(lambda _: _)
|
||||
return v
|
||||
|
||||
def main_thread_creator(next_creator, *args, **kwargs):
|
||||
def main_thread_creator(next_creator, **kwargs):
|
||||
# We are not using the underlying next_creator for test purposes.
|
||||
del next_creator, args, kwargs
|
||||
del next_creator, kwargs
|
||||
return "main_thread"
|
||||
|
||||
with context.graph_mode(), \
|
||||
|
@ -75,9 +75,10 @@ def init_var_from_numpy(input_var, numpy_input, session):
|
||||
|
||||
def one_host_numpy_dataset(numpy_input, colocate_with, session):
|
||||
"""Create a dataset on `colocate_with` from `numpy_input`."""
|
||||
def create_colocated_variable(next_creator, *args, **kwargs):
|
||||
|
||||
def create_colocated_variable(next_creator, **kwargs):
|
||||
kwargs["colocate_with"] = colocate_with
|
||||
return next_creator(*args, **kwargs)
|
||||
return next_creator(**kwargs)
|
||||
|
||||
numpy_flat = nest.flatten(numpy_input)
|
||||
with variable_scope.variable_creator_scope(create_colocated_variable):
|
||||
|
@ -253,17 +253,17 @@ class OneDeviceExtended(distribute_lib.StrategyExtendedV1):
|
||||
worker_device_pairs = [(self._input_device, [self._device])]
|
||||
self._input_workers = input_lib.InputWorkers(worker_device_pairs)
|
||||
|
||||
def _create_variable(self, next_creator, *args, **kwargs):
|
||||
def _create_variable(self, next_creator, **kwargs):
|
||||
colocate_with = kwargs.pop("colocate_with", None)
|
||||
if colocate_with is None:
|
||||
with ops.device(self._device):
|
||||
return next_creator(*args, **kwargs)
|
||||
return next_creator(**kwargs)
|
||||
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
|
||||
with ops.device(colocate_with.device):
|
||||
return next_creator(*args, **kwargs)
|
||||
return next_creator(**kwargs)
|
||||
else:
|
||||
with ops.colocate_with(colocate_with):
|
||||
return next_creator(*args, **kwargs)
|
||||
return next_creator(**kwargs)
|
||||
|
||||
def _validate_colocate_with_variable(self, colocate_with_variable):
|
||||
values.validate_colocate(colocate_with_variable, self)
|
||||
|
@ -388,7 +388,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
||||
|
||||
# TODO(yuefengz): Not all ops in device_setter.STANDARD_PS_OPS will go through
|
||||
# this creator, such as "MutableHashTable".
|
||||
def _create_variable(self, next_creator, *args, **kwargs):
|
||||
def _create_variable(self, next_creator, **kwargs):
|
||||
if self._num_replicas_in_sync > 1:
|
||||
aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
|
||||
if aggregation not in (
|
||||
@ -400,7 +400,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
||||
raise ValueError("Invalid variable aggregation mode: " + aggregation +
|
||||
" for variable: " + kwargs["name"])
|
||||
|
||||
def var_creator(*args, **kwargs):
|
||||
def var_creator(**kwargs):
|
||||
"""Create an AggregatingVariable and fix up collections."""
|
||||
# Record what collections this variable should be added to.
|
||||
collections = kwargs.pop("collections", None)
|
||||
@ -409,7 +409,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
||||
kwargs["collections"] = []
|
||||
|
||||
# Create and wrap the variable.
|
||||
v = next_creator(*args, **kwargs)
|
||||
v = next_creator(**kwargs)
|
||||
wrapped = values.AggregatingVariable(
|
||||
self._container_strategy(), v, aggregation)
|
||||
|
||||
@ -440,14 +440,14 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
||||
colocate_with = kwargs["colocate_with"]
|
||||
if isinstance(colocate_with, numpy_dataset.SingleDevice):
|
||||
with ops.device(colocate_with.device):
|
||||
return var_creator(*args, **kwargs)
|
||||
return var_creator(**kwargs)
|
||||
with ops.device(None):
|
||||
with ops.colocate_with(colocate_with):
|
||||
return var_creator(*args, **kwargs)
|
||||
return var_creator(**kwargs)
|
||||
|
||||
with ops.colocate_with(None, ignore_existing=True):
|
||||
with ops.device(self._variable_device):
|
||||
return var_creator(*args, **kwargs)
|
||||
return var_creator(**kwargs)
|
||||
|
||||
def _call_for_each_replica(self, fn, args, kwargs):
|
||||
# pylint: disable=protected-access
|
||||
|
@ -63,19 +63,19 @@ def make_fn(shared_variable_store, device_id):
|
||||
variable_scope_access_index = {}
|
||||
assert isinstance(device_id, int)
|
||||
|
||||
def create_new_variable(next_creator, *args, **kwargs):
|
||||
def create_new_variable(next_creator, **kwargs):
|
||||
"""Create the variable using `next_creator` and store it."""
|
||||
canonical_name = _canonicalize_variable_name(kwargs.get("name"))
|
||||
v = next_creator(*args, **kwargs)
|
||||
v = next_creator(**kwargs)
|
||||
|
||||
if canonical_name not in shared_variable_store:
|
||||
shared_variable_store[canonical_name] = []
|
||||
shared_variable_store[canonical_name].append(v)
|
||||
return v
|
||||
|
||||
def reuse_variable(next_creator, *args, **kwargs):
|
||||
def reuse_variable(next_creator, **kwargs):
|
||||
"""Re-use existing variable from store with same name (in order)."""
|
||||
del next_creator, args
|
||||
del next_creator
|
||||
name = kwargs.get("name")
|
||||
canonical_name = _canonicalize_variable_name(name)
|
||||
|
||||
|
@ -508,21 +508,21 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
"""
|
||||
tpu_strategy_util.initialize_tpu_system(self._tpu_cluster_resolver)
|
||||
|
||||
def _create_variable(self, next_creator, *args, **kwargs):
|
||||
def _create_variable(self, next_creator, **kwargs):
|
||||
"""Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
|
||||
if kwargs.pop("skip_mirrored_creator", False):
|
||||
return next_creator(*args, **kwargs)
|
||||
return next_creator(**kwargs)
|
||||
|
||||
colocate_with = kwargs.pop("colocate_with", None)
|
||||
if colocate_with is None:
|
||||
devices = self._tpu_devices[:, self._logical_device_stack[-1]]
|
||||
elif isinstance(colocate_with, numpy_dataset.SingleDevice):
|
||||
with ops.device(colocate_with.device):
|
||||
return next_creator(*args, **kwargs)
|
||||
return next_creator(**kwargs)
|
||||
else:
|
||||
devices = colocate_with.devices
|
||||
|
||||
def _real_mirrored_creator(*args, **kwargs): # pylint: disable=g-missing-docstring
|
||||
def _real_mirrored_creator(**kwargs): # pylint: disable=g-missing-docstring
|
||||
initial_value = None
|
||||
value_list = []
|
||||
for i, d in enumerate(devices):
|
||||
@ -545,18 +545,17 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
kwargs["initial_value"] = initial_value
|
||||
|
||||
with context.device_policy(context.DEVICE_PLACEMENT_SILENT):
|
||||
v = next_creator(*args, **kwargs)
|
||||
v = next_creator(**kwargs)
|
||||
|
||||
assert not isinstance(v, values.TPUMirroredVariable)
|
||||
value_list.append(v)
|
||||
return value_list
|
||||
|
||||
return values.create_mirrored_variable(
|
||||
self._container_strategy(),
|
||||
_real_mirrored_creator,
|
||||
values.TPUMirroredVariable,
|
||||
values.TPUSyncOnReadVariable,
|
||||
*args, **kwargs)
|
||||
return values.create_mirrored_variable(self._container_strategy(),
|
||||
_real_mirrored_creator,
|
||||
values.TPUMirroredVariable,
|
||||
values.TPUSyncOnReadVariable,
|
||||
**kwargs)
|
||||
|
||||
def _reduce_to(self, reduce_op, value, destinations):
|
||||
if (isinstance(value, values.DistributedValues) or
|
||||
|
@ -728,8 +728,7 @@ class _MirroredSaveable(saver.BaseSaverBuilder.ResourceVariableSaveable):
|
||||
|
||||
|
||||
def create_mirrored_variable( # pylint: disable=missing-docstring
|
||||
strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls,
|
||||
*args, **kwargs):
|
||||
strategy, real_mirrored_creator, mirrored_cls, sync_on_read_cls, **kwargs):
|
||||
# Figure out what collections this variable should be added to.
|
||||
# We'll add the MirroredVariable to those collections instead.
|
||||
var_collections = kwargs.pop("collections", None)
|
||||
@ -772,7 +771,7 @@ def create_mirrored_variable( # pylint: disable=missing-docstring
|
||||
# was never recorded on the tape instead of having to do this manually
|
||||
# here.
|
||||
with tape.stop_recording():
|
||||
value_list = real_mirrored_creator(*args, **kwargs)
|
||||
value_list = real_mirrored_creator(**kwargs)
|
||||
var_cls = sync_on_read_cls if is_sync_on_read else mirrored_cls
|
||||
result = var_cls(strategy, value_list, aggregation)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user