Clean up the ResourceVariable inheritance hierarchy a bit.
PiperOrigin-RevId: 253592362
This commit is contained in:
parent
a5b860a7cc
commit
aa93ea6441
@ -63,7 +63,7 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
|
||||
return handle
|
||||
|
||||
|
||||
class SharedVariable(resource_variable_ops.ResourceVariable):
|
||||
class SharedVariable(resource_variable_ops.BaseResourceVariable):
|
||||
"""Experimental Variable designed for parameter server training.
|
||||
|
||||
A SharedVariable has a name and two instances of SharedVariable with the
|
||||
@ -231,7 +231,7 @@ class SharedVariable(resource_variable_ops.ResourceVariable):
|
||||
self._graph_element = None
|
||||
self._cached_value = None
|
||||
|
||||
self._handle_deleter = None
|
||||
self._handle_deleter = object()
|
||||
self._cached_shape_as_list = None
|
||||
|
||||
|
||||
|
@ -48,7 +48,7 @@ def check_destinations(destinations):
|
||||
Boolean which is True if `destinations` is not empty.
|
||||
"""
|
||||
# Calling bool() on a ResourceVariable is not allowed.
|
||||
if isinstance(destinations, resource_variable_ops.ResourceVariable):
|
||||
if isinstance(destinations, resource_variable_ops.BaseResourceVariable):
|
||||
return bool(destinations.device)
|
||||
return bool(destinations)
|
||||
|
||||
@ -56,7 +56,7 @@ def check_destinations(destinations):
|
||||
def validate_destinations(destinations):
|
||||
if not isinstance(destinations,
|
||||
(value_lib.DistributedValues,
|
||||
resource_variable_ops.ResourceVariable,
|
||||
resource_variable_ops.BaseResourceVariable,
|
||||
value_lib.AggregatingVariable,
|
||||
six.string_types,
|
||||
value_lib.TPUMirroredVariable,
|
||||
|
@ -479,7 +479,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
|
||||
def _update(self, var, fn, args, kwargs, group):
|
||||
if isinstance(var, values.AggregatingVariable):
|
||||
var = var.get()
|
||||
if not isinstance(var, resource_variable_ops.ResourceVariable):
|
||||
if not isinstance(var, resource_variable_ops.BaseResourceVariable):
|
||||
raise ValueError(
|
||||
"You can not update `var` %r. It must be a Variable." % var)
|
||||
with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
|
||||
|
@ -519,7 +519,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
|
||||
def _update(self, var, fn, args, kwargs, group):
|
||||
assert isinstance(var, values.TPUMirroredVariable) or isinstance(
|
||||
var, resource_variable_ops.ResourceVariable)
|
||||
var, resource_variable_ops.BaseResourceVariable)
|
||||
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
|
||||
if group:
|
||||
return fn(var, *args, **kwargs)
|
||||
@ -540,7 +540,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||
|
||||
def read_var(self, var):
|
||||
assert isinstance(var, values.TPUMirroredVariable) or isinstance(
|
||||
var, resource_variable_ops.ResourceVariable)
|
||||
var, resource_variable_ops.BaseResourceVariable)
|
||||
return var.read_value()
|
||||
|
||||
def _local_results(self, val):
|
||||
|
@ -111,6 +111,8 @@ class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
|
||||
shape and `validate_shape` is `True`.
|
||||
RuntimeError: If called outside of a function definition.
|
||||
"""
|
||||
with ops.init_scope():
|
||||
self._in_graph_mode = not context.executing_eagerly()
|
||||
if not ops.inside_function():
|
||||
# If we've been init_scope()d out of the function definition nothing to do
|
||||
# here; we can't really do the capturing or conditional logic.
|
||||
|
@ -598,7 +598,7 @@ class ConcreteFunction(object):
|
||||
return self._call_flat(
|
||||
(t for t in nest.flatten((args, kwargs), expand_composites=True)
|
||||
if isinstance(t, (ops.Tensor,
|
||||
resource_variable_ops.ResourceVariable))),
|
||||
resource_variable_ops.BaseResourceVariable))),
|
||||
self.captured_inputs)
|
||||
|
||||
def _call_flat(self, args, captured_inputs):
|
||||
@ -632,7 +632,7 @@ class ConcreteFunction(object):
|
||||
tensor_inputs = []
|
||||
variables_used = set([])
|
||||
for i, arg in enumerate(args):
|
||||
if isinstance(arg, resource_variable_ops.ResourceVariable):
|
||||
if isinstance(arg, resource_variable_ops.BaseResourceVariable):
|
||||
# We can pass a variable more than once, and in this case we need to
|
||||
# pass its handle only once.
|
||||
if arg.handle in variables_used:
|
||||
|
@ -156,7 +156,7 @@ def _lift_unlifted_variables(graph, variable_holder):
|
||||
def _should_lift_variable(v):
|
||||
return ((v._in_graph_mode # pylint: disable=protected-access
|
||||
and v.graph.building_function)
|
||||
and isinstance(v, resource_variable_ops.ResourceVariable)
|
||||
and isinstance(v, resource_variable_ops.BaseResourceVariable)
|
||||
and v.handle not in existing_captures)
|
||||
|
||||
for old_variable in global_collection_variables:
|
||||
|
@ -794,7 +794,7 @@ def func_graph_from_py_func(name,
|
||||
inputs = []
|
||||
for arg in (nest.flatten(func_args, expand_composites=True) +
|
||||
nest.flatten(func_kwargs, expand_composites=True)):
|
||||
if isinstance(arg, resource_variable_ops.ResourceVariable):
|
||||
if isinstance(arg, resource_variable_ops.BaseResourceVariable):
|
||||
# Even if an argument variable was not used in the function, we've
|
||||
# already manually captured the resource Tensor when creating argument
|
||||
# placeholders.
|
||||
@ -1003,7 +1003,7 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None):
|
||||
"_user_specified_name",
|
||||
attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
|
||||
function_inputs.append(placeholder)
|
||||
elif isinstance(arg, resource_variable_ops.ResourceVariable):
|
||||
elif isinstance(arg, resource_variable_ops.BaseResourceVariable):
|
||||
# Capture arg variables to create placeholders for them. These will be
|
||||
# removed as captures after the function is traced (since otherwise we'd
|
||||
# just add it back with a new placeholder when the variable was
|
||||
|
@ -790,7 +790,7 @@ class _FuncGraph(ops.Graph):
|
||||
collections=collections,
|
||||
use_resource=use_resource)
|
||||
self.extra_vars.append(var)
|
||||
if (isinstance(var, resource_variable_ops.ResourceVariable) and
|
||||
if (isinstance(var, resource_variable_ops.BaseResourceVariable) and
|
||||
self._capture_resource_var_by_value):
|
||||
# For resource-based variables read the variable outside the function
|
||||
# and pass in the value. This ensures that the function is pure and
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -38,7 +38,7 @@ from tensorflow.python.util import tf_inspect
|
||||
|
||||
|
||||
def _is_tensor(t):
|
||||
return isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
|
||||
return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
|
||||
|
||||
|
||||
def _call_concrete_function(function, inputs):
|
||||
|
@ -604,9 +604,9 @@ class Optimizer(
|
||||
# We colocate all ops created in _apply_dense or _apply_sparse
|
||||
# on the same device as the variable.
|
||||
# TODO(apassos): figure out how to get the variable name here.
|
||||
if context.executing_eagerly() or isinstance(
|
||||
var,
|
||||
resource_variable_ops.ResourceVariable) and not var._in_graph_mode: # pylint: disable=protected-access
|
||||
if (context.executing_eagerly() or
|
||||
isinstance(var, resource_variable_ops.BaseResourceVariable)
|
||||
and not var._in_graph_mode): # pylint: disable=protected-access
|
||||
scope_name = ""
|
||||
else:
|
||||
scope_name = var.op.name
|
||||
@ -617,7 +617,8 @@ class Optimizer(
|
||||
else:
|
||||
with ops.control_dependencies([self._finish(update_ops, "update")]):
|
||||
with ops.colocate_with(global_step):
|
||||
if isinstance(global_step, resource_variable_ops.ResourceVariable):
|
||||
if isinstance(
|
||||
global_step, resource_variable_ops.BaseResourceVariable):
|
||||
# TODO(apassos): the implicit read in assign_add is slow; consider
|
||||
# making it less so.
|
||||
apply_updates = resource_variable_ops.assign_add_variable_op(
|
||||
|
@ -179,7 +179,7 @@ def saveable_objects_for_op(op, name):
|
||||
# pylint: enable=protected-access
|
||||
else:
|
||||
# A variable or tensor.
|
||||
if isinstance(op, resource_variable_ops.ResourceVariable):
|
||||
if isinstance(op, resource_variable_ops.BaseResourceVariable):
|
||||
# pylint: disable=protected-access
|
||||
if op._in_graph_mode:
|
||||
variable = op._graph_element
|
||||
@ -233,7 +233,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
|
||||
# pylint: disable=protected-access
|
||||
for var in op_list:
|
||||
resource_or_ref_variable = (
|
||||
isinstance(var, resource_variable_ops.ResourceVariable) or
|
||||
isinstance(var, resource_variable_ops.BaseResourceVariable) or
|
||||
isinstance(var, variables.RefVariable))
|
||||
|
||||
if isinstance(var, saveable_object.SaveableObject):
|
||||
@ -263,7 +263,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
|
||||
# indicating whether they were created in a graph building context. We
|
||||
# also get Tensors when graph building, which do not have this property.
|
||||
if not getattr(var, "_in_graph_mode", True):
|
||||
if not isinstance(var, resource_variable_ops.ResourceVariable):
|
||||
if not isinstance(var, resource_variable_ops.BaseResourceVariable):
|
||||
raise ValueError(
|
||||
"Can only save/restore ResourceVariables when eager execution "
|
||||
"is enabled, type: %s." % type(var))
|
||||
@ -277,7 +277,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
|
||||
(var._shared_name,))
|
||||
else:
|
||||
if convert_variable_to_tensor:
|
||||
if isinstance(var, resource_variable_ops.ResourceVariable):
|
||||
if isinstance(var, resource_variable_ops.BaseResourceVariable):
|
||||
var = var._graph_element # pylint: disable=protected-access
|
||||
else:
|
||||
var = ops.internal_convert_to_tensor(var, as_ref=True)
|
||||
|
Loading…
Reference in New Issue
Block a user