Merges variable_scope.variable and tf.Variable
PiperOrigin-RevId: 205267974
This commit is contained in:
parent
9fa89160a4
commit
88c520b3f1
@ -1294,16 +1294,3 @@ def is_resource_variable(var):
|
|||||||
""""Returns True if `var` is to be considered a ResourceVariable."""
|
""""Returns True if `var` is to be considered a ResourceVariable."""
|
||||||
return isinstance(var, ResourceVariable) or hasattr(
|
return isinstance(var, ResourceVariable) or hasattr(
|
||||||
var, "_should_act_as_resource_variable")
|
var, "_should_act_as_resource_variable")
|
||||||
|
|
||||||
|
|
||||||
_DEFAULT_USE_RESOURCE = False
|
|
||||||
|
|
||||||
|
|
||||||
def _default_variable_creator(_, *args, **kwds):
|
|
||||||
use_resource = kwds.pop("use_resource", _DEFAULT_USE_RESOURCE)
|
|
||||||
use_resource = use_resource or context.executing_eagerly()
|
|
||||||
if use_resource:
|
|
||||||
return ResourceVariable(*args, **kwds)
|
|
||||||
return variables.RefVariable(*args, **kwds)
|
|
||||||
|
|
||||||
variables.default_variable_creator = _default_variable_creator
|
|
||||||
|
@ -2349,7 +2349,10 @@ def default_variable_creator(next_creator=None, **kwargs):
|
|||||||
validate_shape = kwargs.get("validate_shape", True)
|
validate_shape = kwargs.get("validate_shape", True)
|
||||||
caching_device = kwargs.get("caching_device", None)
|
caching_device = kwargs.get("caching_device", None)
|
||||||
name = kwargs.get("name", None)
|
name = kwargs.get("name", None)
|
||||||
|
variable_def = kwargs.get("variable_def", None)
|
||||||
dtype = kwargs.get("dtype", None)
|
dtype = kwargs.get("dtype", None)
|
||||||
|
expected_shape = kwargs.get("expected_shape", None)
|
||||||
|
import_scope = kwargs.get("import_scope", None)
|
||||||
constraint = kwargs.get("constraint", None)
|
constraint = kwargs.get("constraint", None)
|
||||||
use_resource = kwargs.get("use_resource", None)
|
use_resource = kwargs.get("use_resource", None)
|
||||||
|
|
||||||
@ -2360,23 +2363,24 @@ def default_variable_creator(next_creator=None, **kwargs):
|
|||||||
|
|
||||||
if use_resource is None:
|
if use_resource is None:
|
||||||
use_resource = get_variable_scope().use_resource
|
use_resource = get_variable_scope().use_resource
|
||||||
if use_resource or (use_resource is None and context.executing_eagerly()):
|
use_resource = use_resource or context.executing_eagerly()
|
||||||
|
if use_resource:
|
||||||
return resource_variable_ops.ResourceVariable(
|
return resource_variable_ops.ResourceVariable(
|
||||||
initial_value=initial_value, trainable=trainable,
|
initial_value=initial_value, trainable=trainable,
|
||||||
collections=collections, validate_shape=validate_shape,
|
collections=collections, validate_shape=validate_shape,
|
||||||
caching_device=caching_device, name=name, dtype=dtype,
|
caching_device=caching_device, name=name, dtype=dtype,
|
||||||
constraint=constraint)
|
constraint=constraint, variable_def=variable_def,
|
||||||
elif not use_resource and context.executing_eagerly():
|
import_scope=import_scope)
|
||||||
raise RuntimeError(
|
|
||||||
"VariableScope should use resource variable when eager execution is"
|
|
||||||
" enabled, but use_resource is False."
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
return variables.Variable(
|
return variables.RefVariable(
|
||||||
initial_value=initial_value, trainable=trainable,
|
initial_value=initial_value, trainable=trainable,
|
||||||
collections=collections, validate_shape=validate_shape,
|
collections=collections, validate_shape=validate_shape,
|
||||||
caching_device=caching_device, name=name, dtype=dtype,
|
caching_device=caching_device, name=name, dtype=dtype,
|
||||||
constraint=constraint)
|
constraint=constraint, variable_def=variable_def,
|
||||||
|
expected_shape=expected_shape, import_scope=import_scope)
|
||||||
|
|
||||||
|
|
||||||
|
variables.default_variable_creator = default_variable_creator
|
||||||
|
|
||||||
|
|
||||||
def _make_getter(captured_getter, captured_previous):
|
def _make_getter(captured_getter, captured_previous):
|
||||||
@ -2384,36 +2388,8 @@ def _make_getter(captured_getter, captured_previous):
|
|||||||
return lambda **kwargs: captured_getter(captured_previous, **kwargs)
|
return lambda **kwargs: captured_getter(captured_previous, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def variable(initial_value=None,
|
# TODO(apassos) remove forwarding symbol
|
||||||
trainable=None,
|
variable = variables.Variable
|
||||||
collections=None,
|
|
||||||
validate_shape=True,
|
|
||||||
caching_device=None,
|
|
||||||
name=None,
|
|
||||||
dtype=None,
|
|
||||||
constraint=None,
|
|
||||||
use_resource=None,
|
|
||||||
synchronization=VariableSynchronization.AUTO,
|
|
||||||
aggregation=VariableAggregation.NONE):
|
|
||||||
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
|
|
||||||
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
|
|
||||||
previous_getter = _make_getter(getter, previous_getter)
|
|
||||||
|
|
||||||
# Reset `aggregation` that is explicitly set as `None` to the enum None value.
|
|
||||||
if aggregation is None:
|
|
||||||
aggregation = VariableAggregation.NONE
|
|
||||||
return previous_getter(
|
|
||||||
initial_value=initial_value,
|
|
||||||
trainable=trainable,
|
|
||||||
collections=collections,
|
|
||||||
validate_shape=validate_shape,
|
|
||||||
caching_device=caching_device,
|
|
||||||
name=name,
|
|
||||||
dtype=dtype,
|
|
||||||
constraint=constraint,
|
|
||||||
use_resource=use_resource,
|
|
||||||
synchronization=synchronization,
|
|
||||||
aggregation=aggregation)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_contextlib.contextmanager
|
@tf_contextlib.contextmanager
|
||||||
|
@ -40,15 +40,15 @@ from tensorflow.python.util.deprecation import deprecated
|
|||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
def default_variable_creator(_, *args, **kwds):
|
def default_variable_creator(_, **kwds):
|
||||||
del args, kwds
|
del kwds
|
||||||
raise NotImplementedError("resource_variable_ops needs to be imported")
|
raise NotImplementedError("variable_scope needs to be imported")
|
||||||
|
|
||||||
|
|
||||||
def _make_getter(captured_getter, captured_previous):
|
def _make_getter(captured_getter, captured_previous):
|
||||||
"""To avoid capturing loop variables."""
|
"""To avoid capturing loop variables."""
|
||||||
def getter(*args, **kwargs):
|
def getter(**kwargs):
|
||||||
return captured_getter(captured_previous, *args, **kwargs)
|
return captured_getter(captured_previous, **kwargs)
|
||||||
return getter
|
return getter
|
||||||
|
|
||||||
|
|
||||||
@ -86,11 +86,48 @@ class VariableAggregation(enum.Enum):
|
|||||||
class VariableMetaclass(type):
|
class VariableMetaclass(type):
|
||||||
"""Metaclass to allow construction of tf.Variable to be overridden."""
|
"""Metaclass to allow construction of tf.Variable to be overridden."""
|
||||||
|
|
||||||
|
def _variable_call(cls,
|
||||||
|
initial_value=None,
|
||||||
|
trainable=None,
|
||||||
|
collections=None,
|
||||||
|
validate_shape=True,
|
||||||
|
caching_device=None,
|
||||||
|
name=None,
|
||||||
|
variable_def=None,
|
||||||
|
dtype=None,
|
||||||
|
expected_shape=None,
|
||||||
|
import_scope=None,
|
||||||
|
constraint=None,
|
||||||
|
use_resource=None,
|
||||||
|
synchronization=VariableSynchronization.AUTO,
|
||||||
|
aggregation=VariableAggregation.NONE):
|
||||||
|
"""Call on Variable class. Useful to force the signature."""
|
||||||
|
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
|
||||||
|
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
|
||||||
|
previous_getter = _make_getter(getter, previous_getter)
|
||||||
|
|
||||||
|
# Reset `aggregation` that is explicitly set as `None` to the enum NONE.
|
||||||
|
if aggregation is None:
|
||||||
|
aggregation = VariableAggregation.NONE
|
||||||
|
return previous_getter(
|
||||||
|
initial_value=initial_value,
|
||||||
|
trainable=trainable,
|
||||||
|
collections=collections,
|
||||||
|
validate_shape=validate_shape,
|
||||||
|
caching_device=caching_device,
|
||||||
|
name=name,
|
||||||
|
variable_def=variable_def,
|
||||||
|
dtype=dtype,
|
||||||
|
expected_shape=expected_shape,
|
||||||
|
import_scope=import_scope,
|
||||||
|
constraint=constraint,
|
||||||
|
use_resource=use_resource,
|
||||||
|
synchronization=synchronization,
|
||||||
|
aggregation=aggregation)
|
||||||
|
|
||||||
def __call__(cls, *args, **kwargs):
|
def __call__(cls, *args, **kwargs):
|
||||||
if cls is Variable:
|
if cls is Variable:
|
||||||
previous_getter = lambda *a, **k: default_variable_creator(None, *a, **k)
|
return cls._variable_call(*args, **kwargs)
|
||||||
# TODO(apassos) use a stack of getters here
|
|
||||||
return previous_getter(*args, **kwargs)
|
|
||||||
else:
|
else:
|
||||||
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
|
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
|
||||||
|
|
||||||
@ -650,7 +687,7 @@ class Variable(six.with_metaclass(VariableMetaclass,
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def from_proto(variable_def, import_scope=None):
|
def from_proto(variable_def, import_scope=None):
|
||||||
"""Returns a `Variable` object created from `variable_def`."""
|
"""Returns a `Variable` object created from `variable_def`."""
|
||||||
return Variable(variable_def=variable_def,
|
return RefVariable(variable_def=variable_def,
|
||||||
import_scope=import_scope)
|
import_scope=import_scope)
|
||||||
|
|
||||||
class SaveSliceInfo(object):
|
class SaveSliceInfo(object):
|
||||||
|
@ -747,7 +747,7 @@ def capture_dependencies(template):
|
|||||||
initial_value=initializer,
|
initial_value=initializer,
|
||||||
name=name,
|
name=name,
|
||||||
**inner_kwargs)
|
**inner_kwargs)
|
||||||
if name.startswith(name_prefix):
|
if name is not None and name.startswith(name_prefix):
|
||||||
scope_stripped_name = name[len(name_prefix) + 1:]
|
scope_stripped_name = name[len(name_prefix) + 1:]
|
||||||
if not checkpointable_parent:
|
if not checkpointable_parent:
|
||||||
return template._add_variable_with_custom_getter( # pylint: disable=protected-access
|
return template._add_variable_with_custom_getter( # pylint: disable=protected-access
|
||||||
|
Loading…
Reference in New Issue
Block a user