From 88c520b3f12a1ee5e63d9f05094ca9f84700ea6e Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Thu, 19 Jul 2018 11:02:05 -0700 Subject: [PATCH] Merges variable_scope.variable and tf.Variable PiperOrigin-RevId: 205267974 --- .../python/ops/resource_variable_ops.py | 13 ----- tensorflow/python/ops/variable_scope.py | 54 +++++------------- tensorflow/python/ops/variables.py | 57 +++++++++++++++---- .../python/training/checkpointable/util.py | 2 +- 4 files changed, 63 insertions(+), 63 deletions(-) diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 5979b76ff24..1f56ad25bf2 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -1294,16 +1294,3 @@ def is_resource_variable(var): """"Returns True if `var` is to be considered a ResourceVariable.""" return isinstance(var, ResourceVariable) or hasattr( var, "_should_act_as_resource_variable") - - -_DEFAULT_USE_RESOURCE = False - - -def _default_variable_creator(_, *args, **kwds): - use_resource = kwds.pop("use_resource", _DEFAULT_USE_RESOURCE) - use_resource = use_resource or context.executing_eagerly() - if use_resource: - return ResourceVariable(*args, **kwds) - return variables.RefVariable(*args, **kwds) - -variables.default_variable_creator = _default_variable_creator diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 0f37dcc0277..aca44bcd449 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -2349,7 +2349,10 @@ def default_variable_creator(next_creator=None, **kwargs): validate_shape = kwargs.get("validate_shape", True) caching_device = kwargs.get("caching_device", None) name = kwargs.get("name", None) + variable_def = kwargs.get("variable_def", None) dtype = kwargs.get("dtype", None) + expected_shape = kwargs.get("expected_shape", None) + import_scope = kwargs.get("import_scope", None) constraint = kwargs.get("constraint", None) use_resource = kwargs.get("use_resource", None) @@ -2360,23 +2363,24 @@ def default_variable_creator(next_creator=None, **kwargs): if use_resource is None: use_resource = get_variable_scope().use_resource - if use_resource or (use_resource is None and context.executing_eagerly()): + use_resource = use_resource or context.executing_eagerly() + if use_resource: return resource_variable_ops.ResourceVariable( initial_value=initial_value, trainable=trainable, collections=collections, validate_shape=validate_shape, caching_device=caching_device, name=name, dtype=dtype, - constraint=constraint) - elif not use_resource and context.executing_eagerly(): - raise RuntimeError( - "VariableScope should use resource variable when eager execution is" - " enabled, but use_resource is False." - ) + constraint=constraint, variable_def=variable_def, + import_scope=import_scope) else: - return variables.Variable( + return variables.RefVariable( initial_value=initial_value, trainable=trainable, collections=collections, validate_shape=validate_shape, caching_device=caching_device, name=name, dtype=dtype, - constraint=constraint) + constraint=constraint, variable_def=variable_def, + expected_shape=expected_shape, import_scope=import_scope) + + +variables.default_variable_creator = default_variable_creator def _make_getter(captured_getter, captured_previous): @@ -2384,36 +2388,8 @@ def _make_getter(captured_getter, captured_previous): return lambda **kwargs: captured_getter(captured_previous, **kwargs) -def variable(initial_value=None, - trainable=None, - collections=None, - validate_shape=True, - caching_device=None, - name=None, - dtype=None, - constraint=None, - use_resource=None, - synchronization=VariableSynchronization.AUTO, - aggregation=VariableAggregation.NONE): - previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) - for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access - previous_getter = _make_getter(getter, previous_getter) - - # Reset `aggregation` that is explicitly set as `None` to the enum None value. - if aggregation is None: - aggregation = VariableAggregation.NONE - return previous_getter( - initial_value=initial_value, - trainable=trainable, - collections=collections, - validate_shape=validate_shape, - caching_device=caching_device, - name=name, - dtype=dtype, - constraint=constraint, - use_resource=use_resource, - synchronization=synchronization, - aggregation=aggregation) +# TODO(apassos) remove forwarding symbol +variable = variables.Variable @tf_contextlib.contextmanager diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 6bb2d6f6696..d03d93beeb1 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -40,15 +40,15 @@ from tensorflow.python.util.deprecation import deprecated from tensorflow.python.util.tf_export import tf_export -def default_variable_creator(_, *args, **kwds): - del args, kwds - raise NotImplementedError("resource_variable_ops needs to be imported") +def default_variable_creator(_, **kwds): + del kwds + raise NotImplementedError("variable_scope needs to be imported") def _make_getter(captured_getter, captured_previous): """To avoid capturing loop variables.""" - def getter(*args, **kwargs): - return captured_getter(captured_previous, *args, **kwargs) + def getter(**kwargs): + return captured_getter(captured_previous, **kwargs) return getter @@ -86,11 +86,48 @@ class VariableAggregation(enum.Enum): class VariableMetaclass(type): """Metaclass to allow construction of tf.Variable to be overridden.""" + def _variable_call(cls, + initial_value=None, + trainable=None, + collections=None, + validate_shape=True, + caching_device=None, + name=None, + variable_def=None, + dtype=None, + expected_shape=None, + import_scope=None, + constraint=None, + use_resource=None, + synchronization=VariableSynchronization.AUTO, + aggregation=VariableAggregation.NONE): + """Call on Variable class. Useful to force the signature.""" + previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs) + for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access + previous_getter = _make_getter(getter, previous_getter) + + # Reset `aggregation` that is explicitly set as `None` to the enum NONE. + if aggregation is None: + aggregation = VariableAggregation.NONE + return previous_getter( + initial_value=initial_value, + trainable=trainable, + collections=collections, + validate_shape=validate_shape, + caching_device=caching_device, + name=name, + variable_def=variable_def, + dtype=dtype, + expected_shape=expected_shape, + import_scope=import_scope, + constraint=constraint, + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) + def __call__(cls, *args, **kwargs): if cls is Variable: - previous_getter = lambda *a, **k: default_variable_creator(None, *a, **k) - # TODO(apassos) use a stack of getters here - return previous_getter(*args, **kwargs) + return cls._variable_call(*args, **kwargs) else: return super(VariableMetaclass, cls).__call__(*args, **kwargs) @@ -650,8 +687,8 @@ class Variable(six.with_metaclass(VariableMetaclass, @staticmethod def from_proto(variable_def, import_scope=None): """Returns a `Variable` object created from `variable_def`.""" - return Variable(variable_def=variable_def, - import_scope=import_scope) + return RefVariable(variable_def=variable_def, + import_scope=import_scope) class SaveSliceInfo(object): """Information on how to save this Variable as a slice. diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py index 6ae5765b133..686232fe270 100644 --- a/tensorflow/python/training/checkpointable/util.py +++ b/tensorflow/python/training/checkpointable/util.py @@ -747,7 +747,7 @@ def capture_dependencies(template): initial_value=initializer, name=name, **inner_kwargs) - if name.startswith(name_prefix): + if name is not None and name.startswith(name_prefix): scope_stripped_name = name[len(name_prefix) + 1:] if not checkpointable_parent: return template._add_variable_with_custom_getter( # pylint: disable=protected-access