Clean up the ResourceVariable inheritance hierarchy a bit.
PiperOrigin-RevId: 253592362
This commit is contained in:
parent
a5b860a7cc
commit
aa93ea6441
@ -63,7 +63,7 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
|
|||||||
return handle
|
return handle
|
||||||
|
|
||||||
|
|
||||||
class SharedVariable(resource_variable_ops.ResourceVariable):
|
class SharedVariable(resource_variable_ops.BaseResourceVariable):
|
||||||
"""Experimental Variable designed for parameter server training.
|
"""Experimental Variable designed for parameter server training.
|
||||||
|
|
||||||
A SharedVariable has a name and two instances of SharedVariable with the
|
A SharedVariable has a name and two instances of SharedVariable with the
|
||||||
@ -231,7 +231,7 @@ class SharedVariable(resource_variable_ops.ResourceVariable):
|
|||||||
self._graph_element = None
|
self._graph_element = None
|
||||||
self._cached_value = None
|
self._cached_value = None
|
||||||
|
|
||||||
self._handle_deleter = None
|
self._handle_deleter = object()
|
||||||
self._cached_shape_as_list = None
|
self._cached_shape_as_list = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ def check_destinations(destinations):
|
|||||||
Boolean which is True if `destinations` is not empty.
|
Boolean which is True if `destinations` is not empty.
|
||||||
"""
|
"""
|
||||||
# Calling bool() on a ResourceVariable is not allowed.
|
# Calling bool() on a ResourceVariable is not allowed.
|
||||||
if isinstance(destinations, resource_variable_ops.ResourceVariable):
|
if isinstance(destinations, resource_variable_ops.BaseResourceVariable):
|
||||||
return bool(destinations.device)
|
return bool(destinations.device)
|
||||||
return bool(destinations)
|
return bool(destinations)
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ def check_destinations(destinations):
|
|||||||
def validate_destinations(destinations):
|
def validate_destinations(destinations):
|
||||||
if not isinstance(destinations,
|
if not isinstance(destinations,
|
||||||
(value_lib.DistributedValues,
|
(value_lib.DistributedValues,
|
||||||
resource_variable_ops.ResourceVariable,
|
resource_variable_ops.BaseResourceVariable,
|
||||||
value_lib.AggregatingVariable,
|
value_lib.AggregatingVariable,
|
||||||
six.string_types,
|
six.string_types,
|
||||||
value_lib.TPUMirroredVariable,
|
value_lib.TPUMirroredVariable,
|
||||||
|
@ -479,7 +479,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
def _update(self, var, fn, args, kwargs, group):
|
def _update(self, var, fn, args, kwargs, group):
|
||||||
if isinstance(var, values.AggregatingVariable):
|
if isinstance(var, values.AggregatingVariable):
|
||||||
var = var.get()
|
var = var.get()
|
||||||
if not isinstance(var, resource_variable_ops.ResourceVariable):
|
if not isinstance(var, resource_variable_ops.BaseResourceVariable):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You can not update `var` %r. It must be a Variable." % var)
|
"You can not update `var` %r. It must be a Variable." % var)
|
||||||
with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
|
with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
|
||||||
|
@ -519,7 +519,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
|
|
||||||
def _update(self, var, fn, args, kwargs, group):
|
def _update(self, var, fn, args, kwargs, group):
|
||||||
assert isinstance(var, values.TPUMirroredVariable) or isinstance(
|
assert isinstance(var, values.TPUMirroredVariable) or isinstance(
|
||||||
var, resource_variable_ops.ResourceVariable)
|
var, resource_variable_ops.BaseResourceVariable)
|
||||||
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
|
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
|
||||||
if group:
|
if group:
|
||||||
return fn(var, *args, **kwargs)
|
return fn(var, *args, **kwargs)
|
||||||
@ -540,7 +540,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
|||||||
|
|
||||||
def read_var(self, var):
|
def read_var(self, var):
|
||||||
assert isinstance(var, values.TPUMirroredVariable) or isinstance(
|
assert isinstance(var, values.TPUMirroredVariable) or isinstance(
|
||||||
var, resource_variable_ops.ResourceVariable)
|
var, resource_variable_ops.BaseResourceVariable)
|
||||||
return var.read_value()
|
return var.read_value()
|
||||||
|
|
||||||
def _local_results(self, val):
|
def _local_results(self, val):
|
||||||
|
@ -111,6 +111,8 @@ class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
|
|||||||
shape and `validate_shape` is `True`.
|
shape and `validate_shape` is `True`.
|
||||||
RuntimeError: If called outside of a function definition.
|
RuntimeError: If called outside of a function definition.
|
||||||
"""
|
"""
|
||||||
|
with ops.init_scope():
|
||||||
|
self._in_graph_mode = not context.executing_eagerly()
|
||||||
if not ops.inside_function():
|
if not ops.inside_function():
|
||||||
# If we've been init_scope()d out of the function definition nothing to do
|
# If we've been init_scope()d out of the function definition nothing to do
|
||||||
# here; we can't really do the capturing or conditional logic.
|
# here; we can't really do the capturing or conditional logic.
|
||||||
|
@ -598,7 +598,7 @@ class ConcreteFunction(object):
|
|||||||
return self._call_flat(
|
return self._call_flat(
|
||||||
(t for t in nest.flatten((args, kwargs), expand_composites=True)
|
(t for t in nest.flatten((args, kwargs), expand_composites=True)
|
||||||
if isinstance(t, (ops.Tensor,
|
if isinstance(t, (ops.Tensor,
|
||||||
resource_variable_ops.ResourceVariable))),
|
resource_variable_ops.BaseResourceVariable))),
|
||||||
self.captured_inputs)
|
self.captured_inputs)
|
||||||
|
|
||||||
def _call_flat(self, args, captured_inputs):
|
def _call_flat(self, args, captured_inputs):
|
||||||
@ -632,7 +632,7 @@ class ConcreteFunction(object):
|
|||||||
tensor_inputs = []
|
tensor_inputs = []
|
||||||
variables_used = set([])
|
variables_used = set([])
|
||||||
for i, arg in enumerate(args):
|
for i, arg in enumerate(args):
|
||||||
if isinstance(arg, resource_variable_ops.ResourceVariable):
|
if isinstance(arg, resource_variable_ops.BaseResourceVariable):
|
||||||
# We can pass a variable more than once, and in this case we need to
|
# We can pass a variable more than once, and in this case we need to
|
||||||
# pass its handle only once.
|
# pass its handle only once.
|
||||||
if arg.handle in variables_used:
|
if arg.handle in variables_used:
|
||||||
|
@ -156,7 +156,7 @@ def _lift_unlifted_variables(graph, variable_holder):
|
|||||||
def _should_lift_variable(v):
|
def _should_lift_variable(v):
|
||||||
return ((v._in_graph_mode # pylint: disable=protected-access
|
return ((v._in_graph_mode # pylint: disable=protected-access
|
||||||
and v.graph.building_function)
|
and v.graph.building_function)
|
||||||
and isinstance(v, resource_variable_ops.ResourceVariable)
|
and isinstance(v, resource_variable_ops.BaseResourceVariable)
|
||||||
and v.handle not in existing_captures)
|
and v.handle not in existing_captures)
|
||||||
|
|
||||||
for old_variable in global_collection_variables:
|
for old_variable in global_collection_variables:
|
||||||
|
@ -794,7 +794,7 @@ def func_graph_from_py_func(name,
|
|||||||
inputs = []
|
inputs = []
|
||||||
for arg in (nest.flatten(func_args, expand_composites=True) +
|
for arg in (nest.flatten(func_args, expand_composites=True) +
|
||||||
nest.flatten(func_kwargs, expand_composites=True)):
|
nest.flatten(func_kwargs, expand_composites=True)):
|
||||||
if isinstance(arg, resource_variable_ops.ResourceVariable):
|
if isinstance(arg, resource_variable_ops.BaseResourceVariable):
|
||||||
# Even if an argument variable was not used in the function, we've
|
# Even if an argument variable was not used in the function, we've
|
||||||
# already manually captured the resource Tensor when creating argument
|
# already manually captured the resource Tensor when creating argument
|
||||||
# placeholders.
|
# placeholders.
|
||||||
@ -1003,7 +1003,7 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None):
|
|||||||
"_user_specified_name",
|
"_user_specified_name",
|
||||||
attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
|
attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
|
||||||
function_inputs.append(placeholder)
|
function_inputs.append(placeholder)
|
||||||
elif isinstance(arg, resource_variable_ops.ResourceVariable):
|
elif isinstance(arg, resource_variable_ops.BaseResourceVariable):
|
||||||
# Capture arg variables to create placeholders for them. These will be
|
# Capture arg variables to create placeholders for them. These will be
|
||||||
# removed as captures after the function is traced (since otherwise we'd
|
# removed as captures after the function is traced (since otherwise we'd
|
||||||
# just add it back with a new placeholder when the variable was
|
# just add it back with a new placeholder when the variable was
|
||||||
|
@ -790,7 +790,7 @@ class _FuncGraph(ops.Graph):
|
|||||||
collections=collections,
|
collections=collections,
|
||||||
use_resource=use_resource)
|
use_resource=use_resource)
|
||||||
self.extra_vars.append(var)
|
self.extra_vars.append(var)
|
||||||
if (isinstance(var, resource_variable_ops.ResourceVariable) and
|
if (isinstance(var, resource_variable_ops.BaseResourceVariable) and
|
||||||
self._capture_resource_var_by_value):
|
self._capture_resource_var_by_value):
|
||||||
# For resource-based variables read the variable outside the function
|
# For resource-based variables read the variable outside the function
|
||||||
# and pass in the value. This ensures that the function is pure and
|
# and pass in the value. This ensures that the function is pure and
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -38,7 +38,7 @@ from tensorflow.python.util import tf_inspect
|
|||||||
|
|
||||||
|
|
||||||
def _is_tensor(t):
|
def _is_tensor(t):
|
||||||
return isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
|
return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
|
||||||
|
|
||||||
|
|
||||||
def _call_concrete_function(function, inputs):
|
def _call_concrete_function(function, inputs):
|
||||||
|
@ -604,9 +604,9 @@ class Optimizer(
|
|||||||
# We colocate all ops created in _apply_dense or _apply_sparse
|
# We colocate all ops created in _apply_dense or _apply_sparse
|
||||||
# on the same device as the variable.
|
# on the same device as the variable.
|
||||||
# TODO(apassos): figure out how to get the variable name here.
|
# TODO(apassos): figure out how to get the variable name here.
|
||||||
if context.executing_eagerly() or isinstance(
|
if (context.executing_eagerly() or
|
||||||
var,
|
isinstance(var, resource_variable_ops.BaseResourceVariable)
|
||||||
resource_variable_ops.ResourceVariable) and not var._in_graph_mode: # pylint: disable=protected-access
|
and not var._in_graph_mode): # pylint: disable=protected-access
|
||||||
scope_name = ""
|
scope_name = ""
|
||||||
else:
|
else:
|
||||||
scope_name = var.op.name
|
scope_name = var.op.name
|
||||||
@ -617,7 +617,8 @@ class Optimizer(
|
|||||||
else:
|
else:
|
||||||
with ops.control_dependencies([self._finish(update_ops, "update")]):
|
with ops.control_dependencies([self._finish(update_ops, "update")]):
|
||||||
with ops.colocate_with(global_step):
|
with ops.colocate_with(global_step):
|
||||||
if isinstance(global_step, resource_variable_ops.ResourceVariable):
|
if isinstance(
|
||||||
|
global_step, resource_variable_ops.BaseResourceVariable):
|
||||||
# TODO(apassos): the implicit read in assign_add is slow; consider
|
# TODO(apassos): the implicit read in assign_add is slow; consider
|
||||||
# making it less so.
|
# making it less so.
|
||||||
apply_updates = resource_variable_ops.assign_add_variable_op(
|
apply_updates = resource_variable_ops.assign_add_variable_op(
|
||||||
|
@ -179,7 +179,7 @@ def saveable_objects_for_op(op, name):
|
|||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
else:
|
else:
|
||||||
# A variable or tensor.
|
# A variable or tensor.
|
||||||
if isinstance(op, resource_variable_ops.ResourceVariable):
|
if isinstance(op, resource_variable_ops.BaseResourceVariable):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
if op._in_graph_mode:
|
if op._in_graph_mode:
|
||||||
variable = op._graph_element
|
variable = op._graph_element
|
||||||
@ -233,7 +233,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
|
|||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
for var in op_list:
|
for var in op_list:
|
||||||
resource_or_ref_variable = (
|
resource_or_ref_variable = (
|
||||||
isinstance(var, resource_variable_ops.ResourceVariable) or
|
isinstance(var, resource_variable_ops.BaseResourceVariable) or
|
||||||
isinstance(var, variables.RefVariable))
|
isinstance(var, variables.RefVariable))
|
||||||
|
|
||||||
if isinstance(var, saveable_object.SaveableObject):
|
if isinstance(var, saveable_object.SaveableObject):
|
||||||
@ -263,7 +263,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
|
|||||||
# indicating whether they were created in a graph building context. We
|
# indicating whether they were created in a graph building context. We
|
||||||
# also get Tensors when graph building, which do not have this property.
|
# also get Tensors when graph building, which do not have this property.
|
||||||
if not getattr(var, "_in_graph_mode", True):
|
if not getattr(var, "_in_graph_mode", True):
|
||||||
if not isinstance(var, resource_variable_ops.ResourceVariable):
|
if not isinstance(var, resource_variable_ops.BaseResourceVariable):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can only save/restore ResourceVariables when eager execution "
|
"Can only save/restore ResourceVariables when eager execution "
|
||||||
"is enabled, type: %s." % type(var))
|
"is enabled, type: %s." % type(var))
|
||||||
@ -277,7 +277,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
|
|||||||
(var._shared_name,))
|
(var._shared_name,))
|
||||||
else:
|
else:
|
||||||
if convert_variable_to_tensor:
|
if convert_variable_to_tensor:
|
||||||
if isinstance(var, resource_variable_ops.ResourceVariable):
|
if isinstance(var, resource_variable_ops.BaseResourceVariable):
|
||||||
var = var._graph_element # pylint: disable=protected-access
|
var = var._graph_element # pylint: disable=protected-access
|
||||||
else:
|
else:
|
||||||
var = ops.internal_convert_to_tensor(var, as_ref=True)
|
var = ops.internal_convert_to_tensor(var, as_ref=True)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user