Merges variable_scope.variable and tf.Variable

PiperOrigin-RevId: 205267974
This commit is contained in:
Alexandre Passos 2018-07-19 11:02:05 -07:00 committed by TensorFlower Gardener
parent 9fa89160a4
commit 88c520b3f1
4 changed files with 63 additions and 63 deletions

View File

@ -1294,16 +1294,3 @@ def is_resource_variable(var):
""""Returns True if `var` is to be considered a ResourceVariable."""
return isinstance(var, ResourceVariable) or hasattr(
var, "_should_act_as_resource_variable")
_DEFAULT_USE_RESOURCE = False
def _default_variable_creator(_, *args, **kwds):
use_resource = kwds.pop("use_resource", _DEFAULT_USE_RESOURCE)
use_resource = use_resource or context.executing_eagerly()
if use_resource:
return ResourceVariable(*args, **kwds)
return variables.RefVariable(*args, **kwds)
variables.default_variable_creator = _default_variable_creator

View File

@ -2349,7 +2349,10 @@ def default_variable_creator(next_creator=None, **kwargs):
validate_shape = kwargs.get("validate_shape", True)
caching_device = kwargs.get("caching_device", None)
name = kwargs.get("name", None)
variable_def = kwargs.get("variable_def", None)
dtype = kwargs.get("dtype", None)
expected_shape = kwargs.get("expected_shape", None)
import_scope = kwargs.get("import_scope", None)
constraint = kwargs.get("constraint", None)
use_resource = kwargs.get("use_resource", None)
@ -2360,23 +2363,24 @@ def default_variable_creator(next_creator=None, **kwargs):
if use_resource is None:
use_resource = get_variable_scope().use_resource
if use_resource or (use_resource is None and context.executing_eagerly()):
use_resource = use_resource or context.executing_eagerly()
if use_resource:
return resource_variable_ops.ResourceVariable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
caching_device=caching_device, name=name, dtype=dtype,
constraint=constraint)
elif not use_resource and context.executing_eagerly():
raise RuntimeError(
"VariableScope should use resource variable when eager execution is"
" enabled, but use_resource is False."
)
constraint=constraint, variable_def=variable_def,
import_scope=import_scope)
else:
return variables.Variable(
return variables.RefVariable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
caching_device=caching_device, name=name, dtype=dtype,
constraint=constraint)
constraint=constraint, variable_def=variable_def,
expected_shape=expected_shape, import_scope=import_scope)
variables.default_variable_creator = default_variable_creator
def _make_getter(captured_getter, captured_previous):
@ -2384,36 +2388,8 @@ def _make_getter(captured_getter, captured_previous):
return lambda **kwargs: captured_getter(captured_previous, **kwargs)
def variable(initial_value=None,
trainable=None,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
dtype=None,
constraint=None,
use_resource=None,
synchronization=VariableSynchronization.AUTO,
aggregation=VariableAggregation.NONE):
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
previous_getter = _make_getter(getter, previous_getter)
# Reset `aggregation` that is explicitly set as `None` to the enum None value.
if aggregation is None:
aggregation = VariableAggregation.NONE
return previous_getter(
initial_value=initial_value,
trainable=trainable,
collections=collections,
validate_shape=validate_shape,
caching_device=caching_device,
name=name,
dtype=dtype,
constraint=constraint,
use_resource=use_resource,
synchronization=synchronization,
aggregation=aggregation)
# TODO(apassos) remove forwarding symbol
variable = variables.Variable
@tf_contextlib.contextmanager

View File

@ -40,15 +40,15 @@ from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
def default_variable_creator(_, *args, **kwds):
del args, kwds
raise NotImplementedError("resource_variable_ops needs to be imported")
def default_variable_creator(_, **kwds):
del kwds
raise NotImplementedError("variable_scope needs to be imported")
def _make_getter(captured_getter, captured_previous):
"""To avoid capturing loop variables."""
def getter(*args, **kwargs):
return captured_getter(captured_previous, *args, **kwargs)
def getter(**kwargs):
return captured_getter(captured_previous, **kwargs)
return getter
@ -86,11 +86,48 @@ class VariableAggregation(enum.Enum):
class VariableMetaclass(type):
"""Metaclass to allow construction of tf.Variable to be overridden."""
def _variable_call(cls,
initial_value=None,
trainable=None,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None,
constraint=None,
use_resource=None,
synchronization=VariableSynchronization.AUTO,
aggregation=VariableAggregation.NONE):
"""Call on Variable class. Useful to force the signature."""
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
previous_getter = _make_getter(getter, previous_getter)
# Reset `aggregation` that is explicitly set as `None` to the enum NONE.
if aggregation is None:
aggregation = VariableAggregation.NONE
return previous_getter(
initial_value=initial_value,
trainable=trainable,
collections=collections,
validate_shape=validate_shape,
caching_device=caching_device,
name=name,
variable_def=variable_def,
dtype=dtype,
expected_shape=expected_shape,
import_scope=import_scope,
constraint=constraint,
use_resource=use_resource,
synchronization=synchronization,
aggregation=aggregation)
def __call__(cls, *args, **kwargs):
if cls is Variable:
previous_getter = lambda *a, **k: default_variable_creator(None, *a, **k)
# TODO(apassos) use a stack of getters here
return previous_getter(*args, **kwargs)
return cls._variable_call(*args, **kwargs)
else:
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
@ -650,8 +687,8 @@ class Variable(six.with_metaclass(VariableMetaclass,
@staticmethod
def from_proto(variable_def, import_scope=None):
"""Returns a `Variable` object created from `variable_def`."""
return Variable(variable_def=variable_def,
import_scope=import_scope)
return RefVariable(variable_def=variable_def,
import_scope=import_scope)
class SaveSliceInfo(object):
"""Information on how to save this Variable as a slice.

View File

@ -747,7 +747,7 @@ def capture_dependencies(template):
initial_value=initializer,
name=name,
**inner_kwargs)
if name.startswith(name_prefix):
if name is not None and name.startswith(name_prefix):
scope_stripped_name = name[len(name_prefix) + 1:]
if not checkpointable_parent:
return template._add_variable_with_custom_getter( # pylint: disable=protected-access