Clean up the ResourceVariable inheritance hierarchy a bit.

PiperOrigin-RevId: 253592362
Authored by Alexandre Passos on 2019-06-17 09:08:06 -07:00; committed by TensorFlower Gardener
parent a5b860a7cc
commit aa93ea6441
13 changed files with 649 additions and 585 deletions
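
Most of the hunks below swap isinstance checks from the concrete ResourceVariable class to a shared base class. A minimal, illustrative sketch of the assumed hierarchy (toy stand-in classes, not the real TensorFlow definitions in resource_variable_ops.py):

    # Toy stand-ins for the assumed hierarchy after this change.
    class BaseResourceVariable(object):
        """Shared handle/read logic for every resource-backed variable."""

    class ResourceVariable(BaseResourceVariable):
        """The usual user-facing variable class."""

    class UninitializedVariable(BaseResourceVariable):
        """A variable whose initializer runs later (e.g. inside a function)."""

    # Checks against the base class accept every concrete subclass.
    assert isinstance(ResourceVariable(), BaseResourceVariable)
    assert isinstance(UninitializedVariable(), BaseResourceVariable)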


@@ -63,7 +63,7 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
return handle
-class SharedVariable(resource_variable_ops.ResourceVariable):
+class SharedVariable(resource_variable_ops.BaseResourceVariable):
"""Experimental Variable designed for parameter server training.
A SharedVariable has a name and two instances of SharedVariable with the
@@ -231,7 +231,7 @@ class SharedVariable(resource_variable_ops.ResourceVariable):
self._graph_element = None
self._cached_value = None
-self._handle_deleter = None
+self._handle_deleter = object()
self._cached_shape_as_list = None


@@ -48,7 +48,7 @@ def check_destinations(destinations):
Boolean which is True if `destinations` is not empty.
"""
# Calling bool() on a ResourceVariable is not allowed.
-if isinstance(destinations, resource_variable_ops.ResourceVariable):
+if isinstance(destinations, resource_variable_ops.BaseResourceVariable):
return bool(destinations.device)
return bool(destinations)
@@ -56,7 +56,7 @@ def check_destinations(destinations):
def validate_destinations(destinations):
if not isinstance(destinations,
(value_lib.DistributedValues,
-resource_variable_ops.ResourceVariable,
+resource_variable_ops.BaseResourceVariable,
value_lib.AggregatingVariable,
six.string_types,
value_lib.TPUMirroredVariable,
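
The comment in the first hunk above notes that calling bool() on a resource variable is not allowed: in graph mode, truth-testing a variable (like truth-testing a Tensor) raises an error, so check_destinations tests the variable's device string instead. A minimal sketch of that special case (hypothetical standalone helper, not the TF source):

    def destinations_non_empty(destinations):
        # Avoid truth-testing the variable object itself (disallowed in graph
        # mode); a resource variable counts as non-empty if it has a device.
        if hasattr(destinations, "handle") and hasattr(destinations, "device"):
            return bool(destinations.device)
        return bool(destinations)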


@@ -479,7 +479,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1):
def _update(self, var, fn, args, kwargs, group):
if isinstance(var, values.AggregatingVariable):
var = var.get()
-if not isinstance(var, resource_variable_ops.ResourceVariable):
+if not isinstance(var, resource_variable_ops.BaseResourceVariable):
raise ValueError(
"You can not update `var` %r. It must be a Variable." % var)
with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):


@@ -519,7 +519,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
def _update(self, var, fn, args, kwargs, group):
assert isinstance(var, values.TPUMirroredVariable) or isinstance(
-var, resource_variable_ops.ResourceVariable)
+var, resource_variable_ops.BaseResourceVariable)
if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
if group:
return fn(var, *args, **kwargs)
@@ -540,7 +540,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
def read_var(self, var):
assert isinstance(var, values.TPUMirroredVariable) or isinstance(
-var, resource_variable_ops.ResourceVariable)
+var, resource_variable_ops.BaseResourceVariable)
return var.read_value()
def _local_results(self, val):


@@ -111,6 +111,8 @@ class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable):
shape and `validate_shape` is `True`.
RuntimeError: If called outside of a function definition.
"""
+with ops.init_scope():
+self._in_graph_mode = not context.executing_eagerly()
if not ops.inside_function():
# If we've been init_scope()d out of the function definition nothing to do
# here; we can't really do the capturing or conditional logic.
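
The two added lines decide graph vs. eager mode from the outermost context rather than from whatever function-building graph is currently tracing. A short sketch of the same pattern (assuming TensorFlow is importable; simplified from the hunk above):

    from tensorflow.python.eager import context
    from tensorflow.python.framework import ops

    def infer_in_graph_mode():
        # init_scope() lifts us out of any function tracing, so this reflects
        # the mode the variable is actually created in.
        with ops.init_scope():
            return not context.executing_eagerly()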


@@ -598,7 +598,7 @@ class ConcreteFunction(object):
return self._call_flat(
(t for t in nest.flatten((args, kwargs), expand_composites=True)
if isinstance(t, (ops.Tensor,
-resource_variable_ops.ResourceVariable))),
+resource_variable_ops.BaseResourceVariable))),
self.captured_inputs)
def _call_flat(self, args, captured_inputs):
@@ -632,7 +632,7 @@ class ConcreteFunction(object):
tensor_inputs = []
variables_used = set([])
for i, arg in enumerate(args):
-if isinstance(arg, resource_variable_ops.ResourceVariable):
+if isinstance(arg, resource_variable_ops.BaseResourceVariable):
# We can pass a variable more than once, and in this case we need to
# pass its handle only once.
if arg.handle in variables_used:
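
Per the comment above, a variable may appear several times in the flattened argument list, but its resource handle should be fed to the underlying function only once. A simplified sketch of that deduplication (hypothetical helper, not the real _call_flat):

    def dedup_inputs(args):
        # Keep the first occurrence of each variable's handle; pass plain
        # tensors through unchanged.
        tensor_inputs, handles_seen = [], set()
        for arg in args:
            handle = getattr(arg, "handle", None)
            if handle is None:
                tensor_inputs.append(arg)
            elif id(handle) not in handles_seen:
                handles_seen.add(id(handle))
                tensor_inputs.append(handle)
        return tensor_inputs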


@@ -156,7 +156,7 @@ def _lift_unlifted_variables(graph, variable_holder):
def _should_lift_variable(v):
return ((v._in_graph_mode # pylint: disable=protected-access
and v.graph.building_function)
-and isinstance(v, resource_variable_ops.ResourceVariable)
+and isinstance(v, resource_variable_ops.BaseResourceVariable)
and v.handle not in existing_captures)
for old_variable in global_collection_variables:


@@ -794,7 +794,7 @@ def func_graph_from_py_func(name,
inputs = []
for arg in (nest.flatten(func_args, expand_composites=True) +
nest.flatten(func_kwargs, expand_composites=True)):
-if isinstance(arg, resource_variable_ops.ResourceVariable):
+if isinstance(arg, resource_variable_ops.BaseResourceVariable):
# Even if an argument variable was not used in the function, we've
# already manually captured the resource Tensor when creating argument
# placeholders.
@@ -1003,7 +1003,7 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None):
"_user_specified_name",
attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name)))
function_inputs.append(placeholder)
-elif isinstance(arg, resource_variable_ops.ResourceVariable):
+elif isinstance(arg, resource_variable_ops.BaseResourceVariable):
# Capture arg variables to create placeholders for them. These will be
# removed as captures after the function is traced (since otherwise we'd
# just add it back with a new placeholder when the variable was
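
Both hunks above deal with variables passed as arguments to a traced function: tracing captures the variable's resource tensor, so the traced graph sees a resource placeholder rather than the Python variable object. A small usage sketch (assuming TensorFlow 2.x is installed):

    import tensorflow as tf

    v = tf.Variable(1.0)

    @tf.function
    def add(var, x):
        # `var` arrives as a captured resource; reading it yields its value.
        return var + x

    print(add(v, tf.constant(2.0)))  # tf.Tensor(3.0, ...)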


@@ -790,7 +790,7 @@ class _FuncGraph(ops.Graph):
collections=collections,
use_resource=use_resource)
self.extra_vars.append(var)
-if (isinstance(var, resource_variable_ops.ResourceVariable) and
+if (isinstance(var, resource_variable_ops.BaseResourceVariable) and
self._capture_resource_var_by_value):
# For resource-based variables read the variable outside the function
# and pass in the value. This ensures that the function is pure and
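
The truncated comment above describes capturing resource variables by value: the variable is read once outside the function, and the function closes over that snapshot, so later assignments to the variable do not change the function's result. A small eager-mode illustration of by-value behaviour (assuming TensorFlow is installed; not the _FuncGraph code itself):

    import tensorflow as tf

    v = tf.Variable(3.0)
    snapshot = v.read_value()          # read once, outside the function

    def pure_fn(x):
        return x + snapshot            # closes over the value, not the variable

    v.assign(100.0)
    print(pure_fn(tf.constant(1.0)))   # still 4.0; the captured value is fixed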

File diff suppressed because it is too large.


@@ -38,7 +38,7 @@ from tensorflow.python.util import tf_inspect
def _is_tensor(t):
-return isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
+return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))
def _call_concrete_function(function, inputs):


@@ -604,9 +604,9 @@ class Optimizer(
# We colocate all ops created in _apply_dense or _apply_sparse
# on the same device as the variable.
# TODO(apassos): figure out how to get the variable name here.
-if context.executing_eagerly() or isinstance(
-var,
-resource_variable_ops.ResourceVariable) and not var._in_graph_mode: # pylint: disable=protected-access
+if (context.executing_eagerly() or
+isinstance(var, resource_variable_ops.BaseResourceVariable)
+and not var._in_graph_mode): # pylint: disable=protected-access
scope_name = ""
else:
scope_name = var.op.name
@@ -617,7 +617,8 @@ class Optimizer(
else:
with ops.control_dependencies([self._finish(update_ops, "update")]):
with ops.colocate_with(global_step):
-if isinstance(global_step, resource_variable_ops.ResourceVariable):
+if isinstance(
+global_step, resource_variable_ops.BaseResourceVariable):
# TODO(apassos): the implicit read in assign_add is slow; consider
# making it less so.
apply_updates = resource_variable_ops.assign_add_variable_op(
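
The hunk above bumps `global_step` by calling the raw resource op directly on the variable's handle instead of going through `global_step.assign_add(1)`. A minimal sketch of that call (assuming TensorFlow is installed; eager mode for brevity):

    import tensorflow as tf
    from tensorflow.python.ops import resource_variable_ops

    global_step = tf.Variable(0, dtype=tf.int64)
    resource_variable_ops.assign_add_variable_op(
        global_step.handle, tf.constant(1, dtype=tf.int64))
    print(int(global_step.numpy()))  # 1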


@@ -179,7 +179,7 @@ def saveable_objects_for_op(op, name):
# pylint: enable=protected-access
else:
# A variable or tensor.
-if isinstance(op, resource_variable_ops.ResourceVariable):
+if isinstance(op, resource_variable_ops.BaseResourceVariable):
# pylint: disable=protected-access
if op._in_graph_mode:
variable = op._graph_element
@@ -233,7 +233,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
# pylint: disable=protected-access
for var in op_list:
resource_or_ref_variable = (
-isinstance(var, resource_variable_ops.ResourceVariable) or
+isinstance(var, resource_variable_ops.BaseResourceVariable) or
isinstance(var, variables.RefVariable))
if isinstance(var, saveable_object.SaveableObject):
@@ -263,7 +263,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
# indicating whether they were created in a graph building context. We
# also get Tensors when graph building, which do not have this property.
if not getattr(var, "_in_graph_mode", True):
-if not isinstance(var, resource_variable_ops.ResourceVariable):
+if not isinstance(var, resource_variable_ops.BaseResourceVariable):
raise ValueError(
"Can only save/restore ResourceVariables when eager execution "
"is enabled, type: %s." % type(var))
@@ -277,7 +277,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
(var._shared_name,))
else:
if convert_variable_to_tensor:
-if isinstance(var, resource_variable_ops.ResourceVariable):
+if isinstance(var, resource_variable_ops.BaseResourceVariable):
var = var._graph_element # pylint: disable=protected-access
else:
var = ops.internal_convert_to_tensor(var, as_ref=True)
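
The last hunk distinguishes how a variable becomes a saveable tensor: a graph-mode resource variable contributes its cached graph element, while anything else is converted with an explicit reference conversion. A much-simplified sketch of that branch (hypothetical helper, assuming the TensorFlow version contemporary with this commit):

    from tensorflow.python.framework import ops

    def to_saveable_tensor(var):
        # Hypothetical helper condensing the branch above.
        if getattr(var, "_in_graph_mode", False):
            return var._graph_element  # pylint: disable=protected-access
        return ops.internal_convert_to_tensor(var, as_ref=True)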