Merges variable_scope.variable and tf.Variable
PiperOrigin-RevId: 205267974
This commit is contained in:
parent
9fa89160a4
commit
88c520b3f1
@ -1294,16 +1294,3 @@ def is_resource_variable(var):
|
||||
""""Returns True if `var` is to be considered a ResourceVariable."""
|
||||
return isinstance(var, ResourceVariable) or hasattr(
|
||||
var, "_should_act_as_resource_variable")
|
||||
|
||||
|
||||
_DEFAULT_USE_RESOURCE = False
|
||||
|
||||
|
||||
def _default_variable_creator(_, *args, **kwds):
|
||||
use_resource = kwds.pop("use_resource", _DEFAULT_USE_RESOURCE)
|
||||
use_resource = use_resource or context.executing_eagerly()
|
||||
if use_resource:
|
||||
return ResourceVariable(*args, **kwds)
|
||||
return variables.RefVariable(*args, **kwds)
|
||||
|
||||
variables.default_variable_creator = _default_variable_creator
|
||||
|
@ -2349,7 +2349,10 @@ def default_variable_creator(next_creator=None, **kwargs):
|
||||
validate_shape = kwargs.get("validate_shape", True)
|
||||
caching_device = kwargs.get("caching_device", None)
|
||||
name = kwargs.get("name", None)
|
||||
variable_def = kwargs.get("variable_def", None)
|
||||
dtype = kwargs.get("dtype", None)
|
||||
expected_shape = kwargs.get("expected_shape", None)
|
||||
import_scope = kwargs.get("import_scope", None)
|
||||
constraint = kwargs.get("constraint", None)
|
||||
use_resource = kwargs.get("use_resource", None)
|
||||
|
||||
@ -2360,23 +2363,24 @@ def default_variable_creator(next_creator=None, **kwargs):
|
||||
|
||||
if use_resource is None:
|
||||
use_resource = get_variable_scope().use_resource
|
||||
if use_resource or (use_resource is None and context.executing_eagerly()):
|
||||
use_resource = use_resource or context.executing_eagerly()
|
||||
if use_resource:
|
||||
return resource_variable_ops.ResourceVariable(
|
||||
initial_value=initial_value, trainable=trainable,
|
||||
collections=collections, validate_shape=validate_shape,
|
||||
caching_device=caching_device, name=name, dtype=dtype,
|
||||
constraint=constraint)
|
||||
elif not use_resource and context.executing_eagerly():
|
||||
raise RuntimeError(
|
||||
"VariableScope should use resource variable when eager execution is"
|
||||
" enabled, but use_resource is False."
|
||||
)
|
||||
constraint=constraint, variable_def=variable_def,
|
||||
import_scope=import_scope)
|
||||
else:
|
||||
return variables.Variable(
|
||||
return variables.RefVariable(
|
||||
initial_value=initial_value, trainable=trainable,
|
||||
collections=collections, validate_shape=validate_shape,
|
||||
caching_device=caching_device, name=name, dtype=dtype,
|
||||
constraint=constraint)
|
||||
constraint=constraint, variable_def=variable_def,
|
||||
expected_shape=expected_shape, import_scope=import_scope)
|
||||
|
||||
|
||||
variables.default_variable_creator = default_variable_creator
|
||||
|
||||
|
||||
def _make_getter(captured_getter, captured_previous):
|
||||
@ -2384,36 +2388,8 @@ def _make_getter(captured_getter, captured_previous):
|
||||
return lambda **kwargs: captured_getter(captured_previous, **kwargs)
|
||||
|
||||
|
||||
def variable(initial_value=None,
|
||||
trainable=None,
|
||||
collections=None,
|
||||
validate_shape=True,
|
||||
caching_device=None,
|
||||
name=None,
|
||||
dtype=None,
|
||||
constraint=None,
|
||||
use_resource=None,
|
||||
synchronization=VariableSynchronization.AUTO,
|
||||
aggregation=VariableAggregation.NONE):
|
||||
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
|
||||
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
|
||||
previous_getter = _make_getter(getter, previous_getter)
|
||||
|
||||
# Reset `aggregation` that is explicitly set as `None` to the enum None value.
|
||||
if aggregation is None:
|
||||
aggregation = VariableAggregation.NONE
|
||||
return previous_getter(
|
||||
initial_value=initial_value,
|
||||
trainable=trainable,
|
||||
collections=collections,
|
||||
validate_shape=validate_shape,
|
||||
caching_device=caching_device,
|
||||
name=name,
|
||||
dtype=dtype,
|
||||
constraint=constraint,
|
||||
use_resource=use_resource,
|
||||
synchronization=synchronization,
|
||||
aggregation=aggregation)
|
||||
# TODO(apassos) remove forwarding symbol
|
||||
variable = variables.Variable
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
|
@ -40,15 +40,15 @@ from tensorflow.python.util.deprecation import deprecated
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
def default_variable_creator(_, *args, **kwds):
|
||||
del args, kwds
|
||||
raise NotImplementedError("resource_variable_ops needs to be imported")
|
||||
def default_variable_creator(_, **kwds):
|
||||
del kwds
|
||||
raise NotImplementedError("variable_scope needs to be imported")
|
||||
|
||||
|
||||
def _make_getter(captured_getter, captured_previous):
|
||||
"""To avoid capturing loop variables."""
|
||||
def getter(*args, **kwargs):
|
||||
return captured_getter(captured_previous, *args, **kwargs)
|
||||
def getter(**kwargs):
|
||||
return captured_getter(captured_previous, **kwargs)
|
||||
return getter
|
||||
|
||||
|
||||
@ -86,11 +86,48 @@ class VariableAggregation(enum.Enum):
|
||||
class VariableMetaclass(type):
|
||||
"""Metaclass to allow construction of tf.Variable to be overridden."""
|
||||
|
||||
def _variable_call(cls,
|
||||
initial_value=None,
|
||||
trainable=None,
|
||||
collections=None,
|
||||
validate_shape=True,
|
||||
caching_device=None,
|
||||
name=None,
|
||||
variable_def=None,
|
||||
dtype=None,
|
||||
expected_shape=None,
|
||||
import_scope=None,
|
||||
constraint=None,
|
||||
use_resource=None,
|
||||
synchronization=VariableSynchronization.AUTO,
|
||||
aggregation=VariableAggregation.NONE):
|
||||
"""Call on Variable class. Useful to force the signature."""
|
||||
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
|
||||
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
|
||||
previous_getter = _make_getter(getter, previous_getter)
|
||||
|
||||
# Reset `aggregation` that is explicitly set as `None` to the enum NONE.
|
||||
if aggregation is None:
|
||||
aggregation = VariableAggregation.NONE
|
||||
return previous_getter(
|
||||
initial_value=initial_value,
|
||||
trainable=trainable,
|
||||
collections=collections,
|
||||
validate_shape=validate_shape,
|
||||
caching_device=caching_device,
|
||||
name=name,
|
||||
variable_def=variable_def,
|
||||
dtype=dtype,
|
||||
expected_shape=expected_shape,
|
||||
import_scope=import_scope,
|
||||
constraint=constraint,
|
||||
use_resource=use_resource,
|
||||
synchronization=synchronization,
|
||||
aggregation=aggregation)
|
||||
|
||||
def __call__(cls, *args, **kwargs):
|
||||
if cls is Variable:
|
||||
previous_getter = lambda *a, **k: default_variable_creator(None, *a, **k)
|
||||
# TODO(apassos) use a stack of getters here
|
||||
return previous_getter(*args, **kwargs)
|
||||
return cls._variable_call(*args, **kwargs)
|
||||
else:
|
||||
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
|
||||
|
||||
@ -650,8 +687,8 @@ class Variable(six.with_metaclass(VariableMetaclass,
|
||||
@staticmethod
|
||||
def from_proto(variable_def, import_scope=None):
|
||||
"""Returns a `Variable` object created from `variable_def`."""
|
||||
return Variable(variable_def=variable_def,
|
||||
import_scope=import_scope)
|
||||
return RefVariable(variable_def=variable_def,
|
||||
import_scope=import_scope)
|
||||
|
||||
class SaveSliceInfo(object):
|
||||
"""Information on how to save this Variable as a slice.
|
||||
|
@ -747,7 +747,7 @@ def capture_dependencies(template):
|
||||
initial_value=initializer,
|
||||
name=name,
|
||||
**inner_kwargs)
|
||||
if name.startswith(name_prefix):
|
||||
if name is not None and name.startswith(name_prefix):
|
||||
scope_stripped_name = name[len(name_prefix) + 1:]
|
||||
if not checkpointable_parent:
|
||||
return template._add_variable_with_custom_getter( # pylint: disable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user