diff --git a/tensorflow/contrib/eager/python/parameter_server.py b/tensorflow/contrib/eager/python/parameter_server.py index d221d9790a6..c96a03dd999 100644 --- a/tensorflow/contrib/eager/python/parameter_server.py +++ b/tensorflow/contrib/eager/python/parameter_server.py @@ -63,7 +63,7 @@ def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode): return handle -class SharedVariable(resource_variable_ops.ResourceVariable): +class SharedVariable(resource_variable_ops.BaseResourceVariable): """Experimental Variable designed for parameter server training. A SharedVariable has a name and two instances of SharedVariable with the @@ -231,7 +231,7 @@ class SharedVariable(resource_variable_ops.ResourceVariable): self._graph_element = None self._cached_value = None - self._handle_deleter = None + self._handle_deleter = object() self._cached_shape_as_list = None diff --git a/tensorflow/python/distribute/cross_device_ops.py b/tensorflow/python/distribute/cross_device_ops.py index e3b8fc0cada..67578ba0c6f 100644 --- a/tensorflow/python/distribute/cross_device_ops.py +++ b/tensorflow/python/distribute/cross_device_ops.py @@ -48,7 +48,7 @@ def check_destinations(destinations): Boolean which is True if `destinations` is not empty. """ # Calling bool() on a ResourceVariable is not allowed. - if isinstance(destinations, resource_variable_ops.ResourceVariable): + if isinstance(destinations, resource_variable_ops.BaseResourceVariable): return bool(destinations.device) return bool(destinations) @@ -56,7 +56,7 @@ def check_destinations(destinations): def validate_destinations(destinations): if not isinstance(destinations, (value_lib.DistributedValues, - resource_variable_ops.ResourceVariable, + resource_variable_ops.BaseResourceVariable, value_lib.AggregatingVariable, six.string_types, value_lib.TPUMirroredVariable, diff --git a/tensorflow/python/distribute/parameter_server_strategy.py b/tensorflow/python/distribute/parameter_server_strategy.py index 7e013cb593e..36ba35ea86f 100644 --- a/tensorflow/python/distribute/parameter_server_strategy.py +++ b/tensorflow/python/distribute/parameter_server_strategy.py @@ -479,7 +479,7 @@ class ParameterServerStrategyExtended(distribute_lib.StrategyExtendedV1): def _update(self, var, fn, args, kwargs, group): if isinstance(var, values.AggregatingVariable): var = var.get() - if not isinstance(var, resource_variable_ops.ResourceVariable): + if not isinstance(var, resource_variable_ops.BaseResourceVariable): raise ValueError( "You can not update `var` %r. It must be a Variable." 
% var) with ops.colocate_with(var), distribute_lib.UpdateContext(var.device): diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index ee41baec291..40bbdd978e9 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -519,7 +519,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): def _update(self, var, fn, args, kwargs, group): assert isinstance(var, values.TPUMirroredVariable) or isinstance( - var, resource_variable_ops.ResourceVariable) + var, resource_variable_ops.BaseResourceVariable) if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access if group: return fn(var, *args, **kwargs) @@ -540,7 +540,7 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): def read_var(self, var): assert isinstance(var, values.TPUMirroredVariable) or isinstance( - var, resource_variable_ops.ResourceVariable) + var, resource_variable_ops.BaseResourceVariable) return var.read_value() def _local_results(self, val): diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py index 9c44a46c2c7..c5571b9bb6a 100644 --- a/tensorflow/python/eager/def_function.py +++ b/tensorflow/python/eager/def_function.py @@ -111,6 +111,8 @@ class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable): shape and `validate_shape` is `True`. RuntimeError: If called outside of a function definition. """ + with ops.init_scope(): + self._in_graph_mode = not context.executing_eagerly() if not ops.inside_function(): # If we've been init_scope()d out of the function definition nothing to do # here; we can't really do the capturing or conditional logic. diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index a9b363ccde8..fe43393a695 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -598,7 +598,7 @@ class ConcreteFunction(object): return self._call_flat( (t for t in nest.flatten((args, kwargs), expand_composites=True) if isinstance(t, (ops.Tensor, - resource_variable_ops.ResourceVariable))), + resource_variable_ops.BaseResourceVariable))), self.captured_inputs) def _call_flat(self, args, captured_inputs): @@ -632,7 +632,7 @@ class ConcreteFunction(object): tensor_inputs = [] variables_used = set([]) for i, arg in enumerate(args): - if isinstance(arg, resource_variable_ops.ResourceVariable): + if isinstance(arg, resource_variable_ops.BaseResourceVariable): # We can pass a variable more than once, and in this case we need to # pass its handle only once. 
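The `function.py` hunks above widen the argument filter in `ConcreteFunction.__call__`/`_call_flat` to `BaseResourceVariable`, and the loop that continues below passes each variable's handle only once even when the same variable appears as several arguments. A minimal, self-contained sketch of that dedup idea (hypothetical helper, not the TF code; anything with a `.handle` attribute stands in for a `BaseResourceVariable`):

```python
def flatten_call_args(args):
  """Collects flat tensor inputs, passing each variable's handle only once."""
  tensor_inputs = []
  seen_handles = set()
  for arg in args:
    handle = getattr(arg, "handle", None)
    if handle is None:
      tensor_inputs.append(arg)       # plain tensor: pass through as-is
    elif handle not in seen_handles:
      seen_handles.add(handle)        # same variable seen again -> handle added once
      tensor_inputs.append(handle)
  return tensor_inputs
```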
if arg.handle in variables_used: diff --git a/tensorflow/python/eager/wrap_function.py b/tensorflow/python/eager/wrap_function.py index f9f8e379c8e..57c571e47f7 100644 --- a/tensorflow/python/eager/wrap_function.py +++ b/tensorflow/python/eager/wrap_function.py @@ -156,7 +156,7 @@ def _lift_unlifted_variables(graph, variable_holder): def _should_lift_variable(v): return ((v._in_graph_mode # pylint: disable=protected-access and v.graph.building_function) - and isinstance(v, resource_variable_ops.ResourceVariable) + and isinstance(v, resource_variable_ops.BaseResourceVariable) and v.handle not in existing_captures) for old_variable in global_collection_variables: diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index 00c00d789d8..56e83f03cce 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -794,7 +794,7 @@ def func_graph_from_py_func(name, inputs = [] for arg in (nest.flatten(func_args, expand_composites=True) + nest.flatten(func_kwargs, expand_composites=True)): - if isinstance(arg, resource_variable_ops.ResourceVariable): + if isinstance(arg, resource_variable_ops.BaseResourceVariable): # Even if an argument variable was not used in the function, we've # already manually captured the resource Tensor when creating argument # placeholders. @@ -1003,7 +1003,7 @@ def _get_defun_inputs(args, names, structure, flat_shapes=None): "_user_specified_name", attr_value_pb2.AttrValue(s=compat.as_bytes(requested_name))) function_inputs.append(placeholder) - elif isinstance(arg, resource_variable_ops.ResourceVariable): + elif isinstance(arg, resource_variable_ops.BaseResourceVariable): # Capture arg variables to create placeholders for them. These will be # removed as captures after the function is traced (since otherwise we'd # just add it back with a new placeholder when the variable was diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 20b8a5608ff..125879316a4 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -790,7 +790,7 @@ class _FuncGraph(ops.Graph): collections=collections, use_resource=use_resource) self.extra_vars.append(var) - if (isinstance(var, resource_variable_ops.ResourceVariable) and + if (isinstance(var, resource_variable_ops.BaseResourceVariable) and self._capture_resource_var_by_value): # For resource-based variables read the variable outside the function # and pass in the value. This ensures that the function is pure and diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 787a6bd96a2..c017a7f037e 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -85,7 +85,7 @@ def _set_handle_shapes_and_types(tensor, handle_data, graph_mode): shapes, types = zip(*[(pair.shape, pair.dtype) for pair in handle_data.shape_and_type]) ranks = [len(s.dim) if not s.unknown_rank else -1 for s in shapes] - shapes = [[d.size for d in s.dim] + shapes = [[d.size for d in s.dim] # pylint: disable=g-complex-comprehension if not s.unknown_rank else None for s in shapes] pywrap_tensorflow.TF_GraphSetOutputHandleShapesAndTypes_wrapper( tensor._op._graph._c_graph, # pylint: disable=protected-access @@ -313,106 +313,38 @@ def variable_accessed(variable): tape.variable_accessed(variable) -class ResourceVariable(variables.VariableV1): - """Variable based on resource handles. 
+class BaseResourceVariable(variables.VariableV1): + """A python variable from an existing handle.""" - See the [Variables How To](https://tensorflow.org/guide/variables) - for a high level overview. - - A `ResourceVariable` allows you to maintain state across subsequent calls to - session.run. - - The `ResourceVariable` constructor requires an initial value for the variable, - which can be a `Tensor` of any type and shape. The initial value defines the - type and shape of the variable. After construction, the type and shape of - the variable are fixed. The value can be changed using one of the assign - methods. - - Just like any `Tensor`, variables created with - `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the - graph. Additionally, all the operators overloaded for the `Tensor` class are - carried over to variables, so you can also add nodes to the graph by just - doing arithmetic on variables. - - Unlike ref-based variable, a ResourceVariable has well-defined semantics. Each - usage of a ResourceVariable in a TensorFlow graph adds a read_value operation - to the graph. The Tensors returned by a read_value operation are guaranteed to - see all modifications to the value of the variable which happen in any - operation on which the read_value depends on (either directly, indirectly, or - via a control dependency) and guaranteed to not see any modification to the - value of the variable from operations that depend on the read_value operation. - Updates from operations that have no dependency relationship to the read_value - operation might or might not be visible to read_value. - - For example, if there is more than one assignment to a ResourceVariable in - a single session.run call there is a well-defined value for each operation - which uses the variable's value if the assignments and the read are connected - by edges in the graph. Consider the following example, in which two writes - can cause tf.Variable and tf.ResourceVariable to behave differently: - - ```python - a = tf.Variable(1.0, use_resource=True) - a.initializer.run() - - assign = a.assign(2.0) - with tf.control_dependencies([assign]): - b = a.read_value() - with tf.control_dependencies([b]): - other_assign = a.assign(3.0) - with tf.control_dependencies([other_assign]): - # Will print 2.0 because the value was read before other_assign ran. If - # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed. - tf.compat.v1.Print(b, [b]).eval() - ``` - """ - - def __init__(self, - initial_value=None, - trainable=None, - collections=None, - validate_shape=True, # pylint: disable=unused-argument - caching_device=None, - name=None, - dtype=None, - variable_def=None, - import_scope=None, - constraint=None, - distribute_strategy=None, - synchronization=None, - aggregation=None, - shape=None): - """Creates a variable. + def __init__( # pylint: disable=super-init-not-called + self, + trainable=None, + shape=None, + dtype=None, + handle=None, + constraint=None, + synchronization=None, + aggregation=None, + distribute_strategy=None, + name=None, + unique_id=None, + handle_name=None, + graph_element=None, + initial_value=None, + initializer_op=None, + is_initialized_op=None, + cached_value=None, + save_slice_info=None, + handle_deleter=None, + **unused_kwargs): + """Creates a variable from a handle. Args: - initial_value: A `Tensor`, or Python object convertible to a `Tensor`, - which is the initial value for the Variable. Can also be a - callable with no argument that returns the initial value when called. 
- (Note that initializer functions from init_ops.py must first be bound - to a shape before being used here.) - trainable: If `True`, the default, also adds the variable to the graph - collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as - the default list of variables to use by the `Optimizer` classes. - Defaults to `True`, unless `synchronization` is set to `ON_READ`, in - which case it defaults to `False`. - collections: List of graph collections keys. The new variable is added to - these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. - validate_shape: Ignored. Provided for compatibility with tf.Variable. - caching_device: Optional device string or function describing where the - Variable should be cached for reading. Defaults to the Variable's - device. If not `None`, caches on another device. Typical use is to - cache on the device where the Ops using the Variable reside, to - deduplicate copying through `Switch` and other conditional statements. - name: Optional name for the variable. Defaults to `'Variable'` and gets - uniquified automatically. - dtype: If set, initial_value will be converted to the given type. - If None, either the datatype will be kept (if initial_value is - a Tensor) or float32 will be used (if it is a Python object convertible - to a Tensor). - variable_def: `VariableDef` protocol buffer. If not None, recreates the - `ResourceVariable` object with its contents. `variable_def` and other - arguments (except for import_scope) are mutually exclusive. - import_scope: Optional `string`. Name scope to add to the - ResourceVariable. Only used when `variable_def` is provided. + trainable: If `True`, GradientTapes automatically watch uses of this + Variable. + shape: The variable's shape. + dtype: The variable's dtype. + handle: The variable's handle constraint: An optional projection function to be applied to the variable after being updated by an `Optimizer` (e.g. used to implement norm constraints or value constraints for layer weights). The function must @@ -420,8 +352,6 @@ class ResourceVariable(variables.VariableV1): variable and return the Tensor for the projected value (which must have the same shape). Constraints are not safe to use when doing asynchronous distributed training. - distribute_strategy: The tf.distribute.Strategy this variable is being - created inside of. synchronization: Indicates when a distributed a variable will be aggregated. Accepted values are constants defined in the class `tf.VariableSynchronization`. By default the synchronization is set to @@ -430,42 +360,59 @@ class ResourceVariable(variables.VariableV1): aggregation: Indicates how a distributed variable will be aggregated. Accepted values are constants defined in the class `tf.VariableAggregation`. - shape: (optional) The shape of this variable. If None, the shape of - `initial_value` will be used. When setting this argument to - `tf.TensorShape(None)` (representing an unspecified shape), the variable - can be assigned with values of different shapes. - - Raises: - ValueError: If the initial value is not specified, or does not have a - shape and `validate_shape` is `True`. - - @compatibility(eager) - When Eager Execution is enabled, the default for the `collections` argument - is `None`, which signifies that this `Variable` will not be added to any - collections. - @end_compatibility + distribute_strategy: The distribution strategy this variable was created + under. + name: The name for this variable. + unique_id: Internal. 
Unique ID for this variable's handle. + handle_name: The name for the variable's handle. + graph_element: Optional, required only in session.run-mode. Pre-created + tensor which reads this variable's value. + initial_value: Optional. Variable's initial value. + initializer_op: Operation which assigns the variable's initial value. + is_initialized_op: Pre-created operation to check whether this variable + is initialized. + cached_value: Pre-created operation to read this variable in a specific + device. + save_slice_info: Metadata for variable partitioning. + handle_deleter: EagerResourceDeleter responsible for cleaning up the + handle. """ + with ops.init_scope(): + self._in_graph_mode = not context.executing_eagerly() + synchronization, aggregation, trainable = ( + variables.validate_synchronization_aggregation_trainable( + synchronization, aggregation, trainable, name)) + self._trainable = trainable + self._synchronization = synchronization + self._aggregation = aggregation + self._save_slice_info = save_slice_info + self._initial_value = initial_value + self._initializer_op = initializer_op + self._is_initialized_op = is_initialized_op + self._graph_element = graph_element + self._cached_value = cached_value self._distribute_strategy = distribute_strategy - if variable_def: - if initial_value is not None: - raise ValueError("variable_def and initial_value are mutually " - "exclusive.") - if context.executing_eagerly(): - raise ValueError("Creating ResourceVariable from variable_def is " - "not supported when eager execution is enabled.") - self._init_from_proto(variable_def, import_scope=import_scope) - else: - self._init_from_args( - initial_value=initial_value, - trainable=trainable, - collections=collections, - caching_device=caching_device, - name=name, - dtype=dtype, - constraint=constraint, - synchronization=synchronization, - aggregation=aggregation, - shape=shape) + # Store the graph key so optimizers know how to only retrieve variables from + # this graph. Guaranteed to be the same as the eager graph_key. + self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + self._shape = tensor_shape.as_shape(shape) + self._dtype = dtypes.as_dtype(dtype) + self._handle = handle + self._graph_element = graph_element + self._unique_id = unique_id + self._handle_name = handle_name + ":0" + self._constraint = constraint + # After the handle has been created, set up a way to clean it up when + # executing eagerly. We'll hold the only reference to the deleter, so that + # when this object is garbage collected the deleter will be too. This + # means ResourceVariables can be part of reference cycles without those + # cycles being uncollectable. + if not self._in_graph_mode: + if handle_deleter is None: + handle_deleter = EagerResourceDeleter( + handle=self._handle, handle_device=self._handle.device) + self._handle_deleter = handle_deleter + self._cached_shape_as_list = None def __repr__(self): if context.executing_eagerly() and not self._in_graph_mode: @@ -476,293 +423,6 @@ class ResourceVariable(variables.VariableV1): return "" % ( self.name, self.get_shape(), self.dtype.name) - def _init_from_args(self, - initial_value=None, - trainable=None, - collections=None, - caching_device=None, - name=None, - dtype=None, - constraint=None, - synchronization=None, - aggregation=None, - shape=None): - """Creates a variable. - - Args: - initial_value: A `Tensor`, or Python object convertible to a `Tensor`, - which is the initial value for the Variable. 
The initial value must have - a shape specified unless `validate_shape` is set to False. Can also be a - callable with no argument that returns the initial value when called. - (Note that initializer functions from init_ops.py must first be bound - to a shape before being used here.) - trainable: If `True`, the default, also adds the variable to the graph - collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as - the default list of variables to use by the `Optimizer` classes. - Defaults to `True`, unless `synchronization` is set to `ON_READ`, in - which case it defaults to `False`. - collections: List of graph collections keys. The new variable is added to - these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. - caching_device: Optional device string or function describing where the - Variable should be cached for reading. Defaults to the Variable's - device. If not `None`, caches on another device. Typical use is to - cache on the device where the Ops using the Variable reside, to - deduplicate copying through `Switch` and other conditional statements. - name: Optional name for the variable. Defaults to `'Variable'` and gets - uniquified automatically. - dtype: If set, initial_value will be converted to the given type. - If None, either the datatype will be kept (if initial_value is - a Tensor) or float32 will be used (if it is a Python object convertible - to a Tensor). - constraint: An optional projection function to be applied to the variable - after being updated by an `Optimizer` (e.g. used to implement norm - constraints or value constraints for layer weights). The function must - take as input the unprojected Tensor representing the value of the - variable and return the Tensor for the projected value - (which must have the same shape). Constraints are not safe to - use when doing asynchronous distributed training. - synchronization: Indicates when a distributed a variable will be - aggregated. Accepted values are constants defined in the class - `tf.VariableSynchronization`. By default the synchronization is set to - `AUTO` and the current `DistributionStrategy` chooses - when to synchronize. - aggregation: Indicates how a distributed variable will be aggregated. - Accepted values are constants defined in the class - `tf.VariableAggregation`. - shape: (optional) The shape of this variable. If None, the shape of - `initial_value` will be used. When setting this argument to - `tf.TensorShape(None)` (representing an unspecified shape), the variable - can be assigned with values of different shapes. - - Raises: - ValueError: If the initial value is not specified, or does not have a - shape and `validate_shape` is `True`. - - @compatibility(eager) - When Eager Execution is enabled, variables are never added to collections. - It is not implicitly added to the `GLOBAL_VARIABLES` or - `TRAINABLE_VARIABLES` collections, and the `collections` argument is - ignored. - @end_compatibility - """ - if initial_value is None: - raise ValueError("initial_value must be specified.") - init_from_fn = callable(initial_value) - - if isinstance(initial_value, ops.Tensor) and hasattr( - initial_value, "graph") and initial_value.graph.building_function: - raise ValueError("Tensor-typed variable initializers must either be " - "wrapped in an init_scope or callable " - "(e.g., `tf.Variable(lambda : " - "tf.truncated_normal([10, 40]))`) when building " - "functions. 
Please file a feature request if this " - "restriction inconveniences you.") - - if collections is None: - collections = [ops.GraphKeys.GLOBAL_VARIABLES] - if not isinstance(collections, (list, tuple, set)): - raise ValueError( - "collections argument to Variable constructor must be a list, tuple, " - "or set. Got %s of type %s" % (collections, type(collections))) - if constraint is not None and not callable(constraint): - raise ValueError("The `constraint` argument must be a callable.") - - if isinstance(initial_value, trackable.CheckpointInitialValue): - self._maybe_initialize_trackable() - self._update_uid = initial_value.checkpoint_position.restore_uid - initial_value = initial_value.wrapped_value - - synchronization, aggregation, trainable = ( - variables.validate_synchronization_aggregation_trainable( - synchronization, aggregation, trainable, name)) - self._synchronization = synchronization - self._aggregation = aggregation - self._trainable = trainable - if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: - collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] - self._save_slice_info = None - # Store the graph key so optimizers know how to only retrieve variables from - # this graph. - self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access - with ops.init_scope(): - self._in_graph_mode = not context.executing_eagerly() - with ops.name_scope(name, "Variable", [] - if init_from_fn else [initial_value]) as name: - # pylint: disable=protected-access - handle_name = ops.name_from_scope_name(name) - if self._in_graph_mode: - shared_name = handle_name - unique_id = shared_name - else: - # When in eager mode use a uid for the shared_name, to prevent - # accidental sharing. - unique_id = "%s_%d" % (handle_name, ops.uid()) - shared_name = context.shared_name() - # Use attr_scope and device(None) to simulate the behavior of - # colocate_with when the variable we want to colocate with doesn't - # yet exist. - device_context_manager = ( - ops.device if self._in_graph_mode else ops.NullContextmanager) - attr = attr_value_pb2.AttrValue( - list=attr_value_pb2.AttrValue.ListValue( - s=[compat.as_bytes("loc:@%s" % handle_name)])) - with ops.get_default_graph()._attr_scope({"_class": attr}): - with ops.name_scope("Initializer"), device_context_manager(None): - initial_value = ops.convert_to_tensor( - initial_value() if init_from_fn else initial_value, - name="initial_value", dtype=dtype) - # Don't use `shape or initial_value.shape` since TensorShape has - # overridden `__bool__`. - self._shape = shape if shape is not None else initial_value.shape - self._handle = eager_safe_variable_handle( - initial_value=initial_value, - shape=self._shape, - shared_name=shared_name, - name=name, - graph_mode=self._in_graph_mode) - # pylint: disable=protected-access - if (self._in_graph_mode and initial_value is not None and - initial_value.op._get_control_flow_context() is not None): - raise ValueError( - "Initializer for variable %s is from inside a control-flow " - "construct, such as a loop or conditional. When creating a " - "variable inside a loop or conditional, use a lambda as the " - "initializer." 
% name) - # pylint: enable=protected-access - self._unique_id = unique_id - self._initial_value = initial_value if self._in_graph_mode else None - self._handle_name = handle_name + ":0" - self._dtype = initial_value.dtype.base_dtype - self._constraint = constraint - - if self._in_graph_mode: - with ops.name_scope("IsInitialized"): - self._is_initialized_op = ( - gen_resource_variable_ops.var_is_initialized_op(self._handle)) - if initial_value is not None: - with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): - # pylint: disable=protected-access - self._initializer_op = ( - gen_resource_variable_ops.assign_variable_op( - self._handle, - variables._try_guard_against_uninitialized_dependencies( - name, - initial_value), - name=n)) - # pylint: enable=protected-access - with ops.name_scope("Read"), ops.colocate_with(self._handle): - # Manually assign reads to the handle's device to avoid log - # messages. - with ops.device(self._handle.device): - value = self._read_variable_op() - self._graph_element = value - if caching_device is not None: - # Variables may be created in a tf.device() or ops.colocate_with() - # context. At the same time, users would expect caching device to - # be independent of this context, and/or would not expect the - # current device context to be merged with the caching device - # spec. Therefore we reset the colocation stack before creating - # the cached value. Note that resetting the colocation stack will - # also reset the device stack. - with ops.colocate_with(None, ignore_existing=True): - with ops.device(caching_device): - self._cached_value = array_ops.identity(value) - else: - self._cached_value = None - else: - gen_resource_variable_ops.assign_variable_op(self._handle, - initial_value) - self._is_initialized_op = None - self._initializer_op = None - self._graph_element = None - if caching_device: - with ops.device(caching_device): - self._cached_value = self._read_variable_op() - else: - self._cached_value = None - if not context.executing_eagerly(): - # Eager variables are only added to collections if they are part of an - # eager variable store (otherwise in an interactive session they would - # hog memory and cause OOM). This is done in ops/variable_scope.py. - ops.add_to_collections(collections, self) - elif ops.GraphKeys.GLOBAL_STEP in collections: - ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self) - - if not self._in_graph_mode: - # After the handle has been created, set up a way to clean it up when - # executing eagerly. We'll hold the only reference to the deleter, so that - # when this object is garbage collected the deleter will be too. This - # means ResourceVariables can be part of reference cycles without those - # cycles being uncollectable, and means that no __del__ will be defined at - # all in graph mode. - self._handle_deleter = EagerResourceDeleter( - handle=self._handle, handle_device=self._handle.device) - - def _init_from_proto(self, variable_def, import_scope=None): - """Initializes from `VariableDef` proto.""" - # Note that init_from_proto is currently not supported in Eager mode. - assert not context.executing_eagerly() - self._in_graph_mode = True - assert isinstance(variable_def, variable_pb2.VariableDef) - if not variable_def.is_resource: - raise ValueError("Trying to restore Variable as ResourceVariable.") - - # Create from variable_def. 
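The removed `_init_from_proto` (re-added further down on the new `ResourceVariable` subclass) rebuilds a variable from a `VariableDef` proto. A hedged graph-mode sketch of that round trip, assuming the `tf.compat.v1` API with eager execution disabled, since the constructor above rejects `variable_def` under eager execution:

```python
import tensorflow.compat.v1 as tf
from tensorflow.python.ops import resource_variable_ops

tf.disable_eager_execution()

v = tf.get_variable("v", shape=[2], use_resource=True)
variable_def = v.to_proto()  # VariableDef with is_resource=True

# Recreate a Python-level ResourceVariable pointing at the same graph nodes.
w = resource_variable_ops.ResourceVariable(variable_def=variable_def)
print(w.name)  # same handle name as v, e.g. "v:0"
```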
- g = ops.get_default_graph() - self._handle = g.as_graph_element( - ops.prepend_name_scope( - variable_def.variable_name, import_scope=import_scope)) - self._shape = tensor_shape.TensorShape( - self._handle.op.get_attr("shape")) - self._handle_name = self._handle.name - self._unique_id = self._handle_name - self._initializer_op = g.as_graph_element( - ops.prepend_name_scope( - variable_def.initializer_name, import_scope=import_scope)) - # Check whether initial_value_name exists for backwards compatibility. - if (hasattr(variable_def, "initial_value_name") and - variable_def.initial_value_name): - self._initial_value = g.as_graph_element( - ops.prepend_name_scope(variable_def.initial_value_name, - import_scope=import_scope)) - else: - self._initial_value = None - synchronization, aggregation, trainable = ( - variables.validate_synchronization_aggregation_trainable( - variable_def.synchronization, - variable_def.aggregation, - variable_def.trainable, - variable_def.variable_name)) - self._synchronization = synchronization - self._aggregation = aggregation - self._trainable = trainable - if variable_def.snapshot_name: - snapshot = g.as_graph_element( - ops.prepend_name_scope( - variable_def.snapshot_name, import_scope=import_scope)) - if snapshot.op.type != "ReadVariableOp": - self._cached_value = snapshot - else: - self._cached_value = None - while snapshot.op.type != "ReadVariableOp": - snapshot = snapshot.op.inputs[0] - self._graph_element = snapshot - else: - self._cached_value = None - # Legacy case for protos without the snapshot name; assume it's the - # following. - self._graph_element = g.get_tensor_by_name( - self._handle.op.name + "/Read/ReadVariableOp:0") - if variable_def.HasField("save_slice_info_def"): - self._save_slice_info = variables.Variable.SaveSliceInfo( - save_slice_info_def=variable_def.save_slice_info_def, - import_scope=import_scope) - else: - self._save_slice_info = None - self._caching_device = None - self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype")) - self._constraint = None - @contextlib.contextmanager def _assign_dependencies(self): """Makes assignments depend on the cached value, if any. @@ -1586,6 +1246,527 @@ class ResourceVariable(variables.VariableV1): "`var = var ** value` to get a new Tensor object.") +class ResourceVariable(BaseResourceVariable): + """Variable based on resource handles. + + See the [Variables How To](https://tensorflow.org/guide/variables) + for a high level overview. + + A `ResourceVariable` allows you to maintain state across subsequent calls to + session.run. + + The `ResourceVariable` constructor requires an initial value for the variable, + which can be a `Tensor` of any type and shape. The initial value defines the + type and shape of the variable. After construction, the type and shape of + the variable are fixed. The value can be changed using one of the assign + methods. + + Just like any `Tensor`, variables created with + `tf.Variable(use_resource=True)` can be used as inputs for other Ops in the + graph. Additionally, all the operators overloaded for the `Tensor` class are + carried over to variables, so you can also add nodes to the graph by just + doing arithmetic on variables. + + Unlike ref-based variable, a ResourceVariable has well-defined semantics. Each + usage of a ResourceVariable in a TensorFlow graph adds a read_value operation + to the graph. 
The Tensors returned by a read_value operation are guaranteed to + see all modifications to the value of the variable which happen in any + operation on which the read_value depends on (either directly, indirectly, or + via a control dependency) and guaranteed to not see any modification to the + value of the variable from operations that depend on the read_value operation. + Updates from operations that have no dependency relationship to the read_value + operation might or might not be visible to read_value. + + For example, if there is more than one assignment to a ResourceVariable in + a single session.run call there is a well-defined value for each operation + which uses the variable's value if the assignments and the read are connected + by edges in the graph. Consider the following example, in which two writes + can cause tf.Variable and tf.ResourceVariable to behave differently: + + ```python + a = tf.Variable(1.0, use_resource=True) + a.initializer.run() + + assign = a.assign(2.0) + with tf.control_dependencies([assign]): + b = a.read_value() + with tf.control_dependencies([b]): + other_assign = a.assign(3.0) + with tf.control_dependencies([other_assign]): + # Will print 2.0 because the value was read before other_assign ran. If + # `a` was a tf.Variable instead, 2.0 or 3.0 could be printed. + tf.compat.v1.Print(b, [b]).eval() + ``` + """ + + def __init__(self, # pylint: disable=super-init-not-called + initial_value=None, + trainable=None, + collections=None, + validate_shape=True, # pylint: disable=unused-argument + caching_device=None, + name=None, + dtype=None, + variable_def=None, + import_scope=None, + constraint=None, + distribute_strategy=None, + synchronization=None, + aggregation=None, + shape=None): + """Creates a variable. + + Args: + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. Can also be a + callable with no argument that returns the initial value when called. + (Note that initializer functions from init_ops.py must first be bound + to a shape before being used here.) + trainable: If `True`, the default, also adds the variable to the graph + collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as + the default list of variables to use by the `Optimizer` classes. + Defaults to `True`, unless `synchronization` is set to `ON_READ`, in + which case it defaults to `False`. + collections: List of graph collections keys. The new variable is added to + these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. + validate_shape: Ignored. Provided for compatibility with tf.Variable. + caching_device: Optional device string or function describing where the + Variable should be cached for reading. Defaults to the Variable's + device. If not `None`, caches on another device. Typical use is to + cache on the device where the Ops using the Variable reside, to + deduplicate copying through `Switch` and other conditional statements. + name: Optional name for the variable. Defaults to `'Variable'` and gets + uniquified automatically. + dtype: If set, initial_value will be converted to the given type. + If None, either the datatype will be kept (if initial_value is + a Tensor) or float32 will be used (if it is a Python object convertible + to a Tensor). + variable_def: `VariableDef` protocol buffer. If not None, recreates the + `ResourceVariable` object with its contents. `variable_def` and other + arguments (except for import_scope) are mutually exclusive. + import_scope: Optional `string`. 
Name scope to add to the + ResourceVariable. Only used when `variable_def` is provided. + constraint: An optional projection function to be applied to the variable + after being updated by an `Optimizer` (e.g. used to implement norm + constraints or value constraints for layer weights). The function must + take as input the unprojected Tensor representing the value of the + variable and return the Tensor for the projected value + (which must have the same shape). Constraints are not safe to + use when doing asynchronous distributed training. + distribute_strategy: The tf.distribute.Strategy this variable is being + created inside of. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. + shape: (optional) The shape of this variable. If None, the shape of + `initial_value` will be used. When setting this argument to + `tf.TensorShape(None)` (representing an unspecified shape), the variable + can be assigned with values of different shapes. + + Raises: + ValueError: If the initial value is not specified, or does not have a + shape and `validate_shape` is `True`. + + @compatibility(eager) + When Eager Execution is enabled, the default for the `collections` argument + is `None`, which signifies that this `Variable` will not be added to any + collections. + @end_compatibility + """ + if variable_def: + if initial_value is not None: + raise ValueError("variable_def and initial_value are mutually " + "exclusive.") + if context.executing_eagerly(): + raise ValueError("Creating ResourceVariable from variable_def is " + "not supported when eager execution is enabled.") + self._init_from_proto(variable_def, import_scope=import_scope) + else: + self._init_from_args( + initial_value=initial_value, + trainable=trainable, + collections=collections, + caching_device=caching_device, + name=name, + dtype=dtype, + constraint=constraint, + synchronization=synchronization, + aggregation=aggregation, + shape=shape, + distribute_strategy=distribute_strategy) + + def _init_from_args(self, + initial_value=None, + trainable=None, + collections=None, + caching_device=None, + name=None, + dtype=None, + constraint=None, + synchronization=None, + aggregation=None, + distribute_strategy=None, + shape=None): + """Creates a variable. + + Args: + initial_value: A `Tensor`, or Python object convertible to a `Tensor`, + which is the initial value for the Variable. The initial value must have + a shape specified unless `validate_shape` is set to False. Can also be a + callable with no argument that returns the initial value when called. + (Note that initializer functions from init_ops.py must first be bound + to a shape before being used here.) + trainable: If `True`, the default, also adds the variable to the graph + collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as + the default list of variables to use by the `Optimizer` classes. + Defaults to `True`, unless `synchronization` is set to `ON_READ`, in + which case it defaults to `False`. + collections: List of graph collections keys. The new variable is added to + these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. 
+ caching_device: Optional device string or function describing where the + Variable should be cached for reading. Defaults to the Variable's + device. If not `None`, caches on another device. Typical use is to + cache on the device where the Ops using the Variable reside, to + deduplicate copying through `Switch` and other conditional statements. + name: Optional name for the variable. Defaults to `'Variable'` and gets + uniquified automatically. + dtype: If set, initial_value will be converted to the given type. + If None, either the datatype will be kept (if initial_value is + a Tensor) or float32 will be used (if it is a Python object convertible + to a Tensor). + constraint: An optional projection function to be applied to the variable + after being updated by an `Optimizer` (e.g. used to implement norm + constraints or value constraints for layer weights). The function must + take as input the unprojected Tensor representing the value of the + variable and return the Tensor for the projected value + (which must have the same shape). Constraints are not safe to + use when doing asynchronous distributed training. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. + distribute_strategy: DistributionStrategy under which this variable + was created. + shape: (optional) The shape of this variable. If None, the shape of + `initial_value` will be used. When setting this argument to + `tf.TensorShape(None)` (representing an unspecified shape), the variable + can be assigned with values of different shapes. + + Raises: + ValueError: If the initial value is not specified, or does not have a + shape and `validate_shape` is `True`. + + @compatibility(eager) + When Eager Execution is enabled, variables are never added to collections. + It is not implicitly added to the `GLOBAL_VARIABLES` or + `TRAINABLE_VARIABLES` collections, and the `collections` argument is + ignored. + @end_compatibility + """ + synchronization, aggregation, trainable = ( + variables.validate_synchronization_aggregation_trainable( + synchronization, aggregation, trainable, name)) + if initial_value is None: + raise ValueError("initial_value must be specified.") + init_from_fn = callable(initial_value) + + if isinstance(initial_value, ops.Tensor) and hasattr( + initial_value, "graph") and initial_value.graph.building_function: + raise ValueError("Tensor-typed variable initializers must either be " + "wrapped in an init_scope or callable " + "(e.g., `tf.Variable(lambda : " + "tf.truncated_normal([10, 40]))`) when building " + "functions. Please file a feature request if this " + "restriction inconveniences you.") + + if collections is None: + collections = [ops.GraphKeys.GLOBAL_VARIABLES] + if not isinstance(collections, (list, tuple, set)): + raise ValueError( + "collections argument to Variable constructor must be a list, tuple, " + "or set. 
Got %s of type %s" % (collections, type(collections))) + if constraint is not None and not callable(constraint): + raise ValueError("The `constraint` argument must be a callable.") + + if isinstance(initial_value, trackable.CheckpointInitialValue): + self._maybe_initialize_trackable() + self._update_uid = initial_value.checkpoint_position.restore_uid + initial_value = initial_value.wrapped_value + + if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: + collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] + with ops.init_scope(): + self._in_graph_mode = not context.executing_eagerly() + with ops.name_scope(name, "Variable", [] + if init_from_fn else [initial_value]) as name: + # pylint: disable=protected-access + handle_name = ops.name_from_scope_name(name) + if self._in_graph_mode: + shared_name = handle_name + unique_id = shared_name + else: + # When in eager mode use a uid for the shared_name, to prevent + # accidental sharing. + unique_id = "%s_%d" % (handle_name, ops.uid()) + shared_name = context.shared_name() + # Use attr_scope and device(None) to simulate the behavior of + # colocate_with when the variable we want to colocate with doesn't + # yet exist. + device_context_manager = ( + ops.device if self._in_graph_mode else ops.NullContextmanager) + attr = attr_value_pb2.AttrValue( + list=attr_value_pb2.AttrValue.ListValue( + s=[compat.as_bytes("loc:@%s" % handle_name)])) + with ops.get_default_graph()._attr_scope({"_class": attr}): + with ops.name_scope("Initializer"), device_context_manager(None): + initial_value = ops.convert_to_tensor( + initial_value() if init_from_fn else initial_value, + name="initial_value", dtype=dtype) + # Don't use `shape or initial_value.shape` since TensorShape has + # overridden `__bool__`. + shape = shape if shape is not None else initial_value.shape + handle = eager_safe_variable_handle( + initial_value=initial_value, + shape=shape, + shared_name=shared_name, + name=name, + graph_mode=self._in_graph_mode) + # pylint: disable=protected-access + if (self._in_graph_mode and initial_value is not None and + initial_value.op._get_control_flow_context() is not None): + raise ValueError( + "Initializer for variable %s is from inside a control-flow " + "construct, such as a loop or conditional. When creating a " + "variable inside a loop or conditional, use a lambda as the " + "initializer." % name) + # pylint: enable=protected-access + dtype = initial_value.dtype.base_dtype + + if self._in_graph_mode: + with ops.name_scope("IsInitialized"): + is_initialized_op = ( + gen_resource_variable_ops.var_is_initialized_op(handle)) + if initial_value is not None: + with ops.name_scope("Assign") as n, ops.colocate_with(handle): + # pylint: disable=protected-access + initializer_op = ( + gen_resource_variable_ops.assign_variable_op( + handle, + variables._try_guard_against_uninitialized_dependencies( + name, + initial_value), + name=n)) + # pylint: enable=protected-access + with ops.name_scope("Read"), ops.colocate_with(handle): + # Manually assign reads to the handle's device to avoid log + # messages. + with ops.device(handle.device): + value = gen_resource_variable_ops.read_variable_op(handle, dtype) + _maybe_set_handle_data(dtype, handle, value) + graph_element = value + if caching_device is not None: + # Variables may be created in a tf.device() or ops.colocate_with() + # context. 
At the same time, users would expect caching device to + # be independent of this context, and/or would not expect the + # current device context to be merged with the caching device + # spec. Therefore we reset the colocation stack before creating + # the cached value. Note that resetting the colocation stack will + # also reset the device stack. + with ops.colocate_with(None, ignore_existing=True): + with ops.device(caching_device): + cached_value = array_ops.identity(value) + else: + cached_value = None + else: + gen_resource_variable_ops.assign_variable_op(handle, initial_value) + is_initialized_op = None + initializer_op = None + graph_element = None + if caching_device: + with ops.device(caching_device): + cached_value = gen_resource_variable_ops.read_variable_op( + handle, dtype) + _maybe_set_handle_data(dtype, handle, cached_value) + else: + cached_value = None + if not context.executing_eagerly(): + # Eager variables are only added to collections if they are part of an + # eager variable store (otherwise in an interactive session they would + # hog memory and cause OOM). This is done in ops/variable_scope.py. + ops.add_to_collections(collections, self) + elif ops.GraphKeys.GLOBAL_STEP in collections: + ops.add_to_collections(ops.GraphKeys.GLOBAL_STEP, self) + initial_value = initial_value if self._in_graph_mode else None + super(ResourceVariable, self).__init__( + trainable=trainable, shape=shape, dtype=dtype, handle=handle, + synchronization=synchronization, constraint=constraint, + aggregation=aggregation, distribute_strategy=distribute_strategy, + name=name, unique_id=unique_id, handle_name=handle_name, + graph_element=graph_element, initial_value=initial_value, + initializer_op=initializer_op, is_initialized_op=is_initialized_op, + cached_value=cached_value) + + def _init_from_proto(self, variable_def, import_scope=None): + """Initializes from `VariableDef` proto.""" + # Note that init_from_proto is currently not supported in Eager mode. + assert not context.executing_eagerly() + self._in_graph_mode = True + assert isinstance(variable_def, variable_pb2.VariableDef) + if not variable_def.is_resource: + raise ValueError("Trying to restore Variable as ResourceVariable.") + + # Create from variable_def. + g = ops.get_default_graph() + self._handle = g.as_graph_element( + ops.prepend_name_scope( + variable_def.variable_name, import_scope=import_scope)) + self._shape = tensor_shape.TensorShape( + self._handle.op.get_attr("shape")) + self._handle_name = self._handle.name + self._unique_id = self._handle_name + self._initializer_op = g.as_graph_element( + ops.prepend_name_scope( + variable_def.initializer_name, import_scope=import_scope)) + # Check whether initial_value_name exists for backwards compatibility. 
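The `caching_device` branch in the `_init_from_args` hunk above resets the colocation stack before creating the cached read. A hedged usage sketch of what that option is for (graph mode, illustrative device names):

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

with tf.device("/job:ps/task:0"):
  weights = tf.get_variable(
      "weights", shape=[1024, 256], use_resource=True,
      # Reads go through an identity cached on the consuming device, so
      # repeated reads on the worker reuse one copied value instead of each
      # fetching from the parameter server.
      caching_device="/job:worker/task:0")
```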
+ if (hasattr(variable_def, "initial_value_name") and + variable_def.initial_value_name): + self._initial_value = g.as_graph_element( + ops.prepend_name_scope(variable_def.initial_value_name, + import_scope=import_scope)) + else: + self._initial_value = None + synchronization, aggregation, trainable = ( + variables.validate_synchronization_aggregation_trainable( + variable_def.synchronization, + variable_def.aggregation, + variable_def.trainable, + variable_def.variable_name)) + self._synchronization = synchronization + self._aggregation = aggregation + self._trainable = trainable + if variable_def.snapshot_name: + snapshot = g.as_graph_element( + ops.prepend_name_scope( + variable_def.snapshot_name, import_scope=import_scope)) + if snapshot.op.type != "ReadVariableOp": + self._cached_value = snapshot + else: + self._cached_value = None + while snapshot.op.type != "ReadVariableOp": + snapshot = snapshot.op.inputs[0] + self._graph_element = snapshot + else: + self._cached_value = None + # Legacy case for protos without the snapshot name; assume it's the + # following. + self._graph_element = g.get_tensor_by_name( + self._handle.op.name + "/Read/ReadVariableOp:0") + if variable_def.HasField("save_slice_info_def"): + self._save_slice_info = variables.Variable.SaveSliceInfo( + save_slice_info_def=variable_def.save_slice_info_def, + import_scope=import_scope) + else: + self._save_slice_info = None + self._caching_device = None + self._dtype = dtypes.as_dtype(self._handle.op.get_attr("dtype")) + self._constraint = None + + +class UninitializedVariable(BaseResourceVariable): + """A variable with no initializer.""" + + def __init__( # pylint: disable=super-init-not-called + self, + trainable=None, + caching_device=None, + name=None, + shape=None, + dtype=None, + constraint=None, + synchronization=None, + aggregation=None, + extra_handle_data=None, + distribute_strategy=None, + **unused_kwargs): + """Creates the variable handle. + + Args: + trainable: If `True`, GradientTapes automatically watch uses of this + Variable. + caching_device: Optional device string or function describing where the + Variable should be cached for reading. Defaults to the Variable's + device. If not `None`, caches on another device. Typical use is to + cache on the device where the Ops using the Variable reside, to + deduplicate copying through `Switch` and other conditional statements. + name: Optional name for the variable. Defaults to `'Variable'` and gets + uniquified automatically. + shape: The variable's shape. + dtype: The variable's dtype. + constraint: An optional projection function to be applied to the variable + after being updated by an `Optimizer` (e.g. used to implement norm + constraints or value constraints for layer weights). The function must + take as input the unprojected Tensor representing the value of the + variable and return the Tensor for the projected value + (which must have the same shape). Constraints are not safe to + use when doing asynchronous distributed training. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + `tf.VariableSynchronization`. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + `tf.VariableAggregation`. 
+ extra_handle_data: Optional, another resource handle or Tensor with handle + data to merge with `shape` and `dtype`. + distribute_strategy: The tf.distribute.Strategy this variable is being + created inside of. + """ + with ops.init_scope(): + self._in_graph_mode = not context.executing_eagerly() + with ops.init_scope(): + with ops.name_scope(name, "Variable") as name: + handle_name = ops.name_from_scope_name(name) + if self._in_graph_mode: + shared_name = handle_name + unique_id = shared_name + else: + unique_id = "%s_%d" % (handle_name, ops.uid()) + shared_name = context.shared_name(unique_id) + handle = variable_handle_from_shape_and_dtype( + shape=shape, dtype=dtype, shared_name=shared_name, + name=name, graph_mode=self._in_graph_mode, + extra_handle_data=extra_handle_data) + if not context.executing_eagerly(): + with ops.name_scope("Read"), ops.colocate_with(handle): + # Manually assign reads to the handle's device to avoid log + # messages. + with ops.device(handle.device): + value = gen_resource_variable_ops.read_variable_op(handle, dtype) + _maybe_set_handle_data(dtype, handle, value) + graph_element = value + ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self) + # Do *not* add to TRAINABLE_VARIABLES here, even if self._trainable, + # because retraining or frozen use of imported SavedModels is + # controlled at higher levels of model building. + else: + graph_element = None + super(UninitializedVariable, self).__init__( + distribute_strategy=distribute_strategy, shape=shape, dtype=dtype, + unique_id=unique_id, handle_name=handle_name, constraint=constraint, + handle=handle, graph_element=graph_element, trainable=trainable, + synchronization=synchronization, aggregation=aggregation) + + pywrap_tensorflow.TFE_Py_RegisterResourceVariableType(ResourceVariable) math_ops._resource_variable_type = ResourceVariable # pylint: disable=protected-access @@ -1596,48 +1777,39 @@ def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False): # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. -ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor) -ops.register_dense_tensor_like_type(ResourceVariable) +ops.register_tensor_conversion_function(BaseResourceVariable, + _dense_var_to_tensor) +ops.register_dense_tensor_like_type(BaseResourceVariable) -class _UnreadVariable(ResourceVariable): +class _UnreadVariable(BaseResourceVariable): """Represents a future for a read of a variable. Pretends to be the tensor if anyone looks. """ - def __init__(self, handle, dtype, # pylint: disable=super-init-not-called - shape, in_graph_mode, deleter, parent_op, unique_id): - # We do not call super init on purpose. 
- self._trainable = False - self._synchronization = None - self._aggregation = None - self._save_slice_info = None - self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access - self._in_graph_mode = in_graph_mode - self._handle = handle - self._shape = shape - self._initial_value = None - if isinstance(self._handle, ops.EagerTensor): - self._handle_name = "" + def __init__(self, handle, dtype, shape, in_graph_mode, deleter, + parent_op, unique_id): + if isinstance(handle, ops.EagerTensor): + handle_name = "" else: - self._handle_name = self._handle.name - self._unique_id = unique_id - self._dtype = dtype - self._constraint = None - self._cached_value = None - self._is_initialized_op = None - self._initializer_op = None - self._parent_op = parent_op + handle_name = handle.name # Only create a graph_element if we're in session.run-land as only # session.run requires a preexisting tensor to evaluate. Otherwise we can # avoid accidentally reading the variable. if (context.executing_eagerly() or ops.get_default_graph()._building_function): # pylint: disable=protected-access - self._graph_element = None + graph_element = None else: - self._graph_element = self.read_value() - self._handle_deleter = deleter + with ops.control_dependencies([parent_op]): + graph_element = gen_resource_variable_ops.read_variable_op( + handle, dtype) + _maybe_set_handle_data(dtype, handle, graph_element) + super(_UnreadVariable, self).__init__( + handle=handle, shape=shape, handle_name=handle_name, + unique_id=unique_id, dtype=dtype, handle_deleter=deleter, + graph_element=graph_element) + self._parent_op = parent_op @property def name(self): @@ -1676,9 +1848,9 @@ def _ReadGrad(_, grad): def variable_shape(handle, out_type=dtypes.int32): if getattr( - handle, "_handle_data", None) is None or not handle._handle_data.is_set: + handle, "_handle_data", None) is None or not handle._handle_data.is_set: # pylint: disable=protected-access return gen_resource_variable_ops.variable_shape(handle, out_type=out_type) - shape_proto = handle._handle_data.shape_and_type[0].shape + shape_proto = handle._handle_data.shape_and_type[0].shape # pylint: disable=protected-access if shape_proto.unknown_rank or any(x.size == -1 for x in shape_proto.dim): return gen_resource_variable_ops.variable_shape(handle, out_type=out_type) return constant_op.constant([x.size for x in shape_proto.dim], dtype=out_type) @@ -1749,121 +1921,10 @@ ops.register_proto_function( def is_resource_variable(var): """"Returns True if `var` is to be considered a ResourceVariable.""" - return isinstance(var, ResourceVariable) or hasattr( + return isinstance(var, BaseResourceVariable) or hasattr( var, "_should_act_as_resource_variable") -# TODO(allenl): Rather than UninitializedVariable inheriting from -# ResourceVariable, ResourceVariable should inherit from UninitializedVariable -# and add its initialization logic. -class UninitializedVariable(ResourceVariable): - """A variable with no initializer.""" - - def __init__( # pylint: disable=super-init-not-called - self, - trainable=None, - caching_device=None, - name=None, - shape=None, - dtype=None, - constraint=None, - synchronization=None, - aggregation=None, - extra_handle_data=None, - distribute_strategy=None, - **unused_kwargs): - """Creates the variable handle. - - Args: - trainable: If `True`, GradientTapes automatically watch uses of this - Variable. - caching_device: Optional device string or function describing where the - Variable should be cached for reading. 
-        device. If not `None`, caches on another device. Typical use is to
-        cache on the device where the Ops using the Variable reside, to
-        deduplicate copying through `Switch` and other conditional statements.
-      name: Optional name for the variable. Defaults to `'Variable'` and gets
-        uniquified automatically.
-      shape: The variable's shape.
-      dtype: The variable's dtype.
-      constraint: An optional projection function to be applied to the variable
-        after being updated by an `Optimizer` (e.g. used to implement norm
-        constraints or value constraints for layer weights). The function must
-        take as input the unprojected Tensor representing the value of the
-        variable and return the Tensor for the projected value
-        (which must have the same shape). Constraints are not safe to
-        use when doing asynchronous distributed training.
-      synchronization: Indicates when a distributed a variable will be
-        aggregated. Accepted values are constants defined in the class
-        `tf.VariableSynchronization`. By default the synchronization is set to
-        `AUTO` and the current `DistributionStrategy` chooses
-        when to synchronize.
-      aggregation: Indicates how a distributed variable will be aggregated.
-        Accepted values are constants defined in the class
-        `tf.VariableAggregation`.
-      extra_handle_data: Optional, another resource handle or Tensor with handle
-        data to merge with `shape` and `dtype`.
-      distribute_strategy: The tf.distribute.Strategy this variable is being
-        created inside of.
-    """
-    with ops.init_scope():
-      self._in_graph_mode = not context.executing_eagerly()
-    synchronization, aggregation, trainable = (
-        variables.validate_synchronization_aggregation_trainable(
-            synchronization, aggregation, trainable, name))
-    self._trainable = trainable
-    self._synchronization = synchronization
-    self._aggregation = aggregation
-    self._save_slice_info = None
-    self._initial_value = None
-    self._initializer_op = None
-    self._is_initialized_op = None
-    self._graph_element = None
-    self._cached_value = None
-    self._distribute_strategy = distribute_strategy
-    # Store the graph key so optimizers know how to only retrieve variables from
-    # this graph. Guaranteed to be the same as the eager graph_key.
-    self._graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
-    self._shape = tensor_shape.as_shape(shape)
-    self._dtype = dtypes.as_dtype(dtype)
-    with ops.init_scope():
-      with ops.name_scope(name, "Variable") as name:
-        handle_name = ops.name_from_scope_name(name)
-        if self._in_graph_mode:
-          shared_name = handle_name
-          unique_id = shared_name
-        else:
-          unique_id = "%s_%d" % (handle_name, ops.uid())
-          shared_name = context.shared_name(unique_id)
-        self._handle = variable_handle_from_shape_and_dtype(
-            shape=shape, dtype=dtype, shared_name=shared_name,
-            name=name, graph_mode=self._in_graph_mode,
-            extra_handle_data=extra_handle_data)
-        if self._in_graph_mode:
-          with ops.name_scope("Read"), ops.colocate_with(self._handle):
-            # Manually assign reads to the handle's device to avoid log
-            # messages.
-            with ops.device(self._handle.device):
-              value = self._read_variable_op()
-            self._graph_element = value
-          ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self)
-          # Do *not* add to TRAINABLE_VARIABLES here, even if self._trainable,
-          # because retraining or frozen use of imported SavedModels is
-          # controlled at higher levels of model building.
-    self._unique_id = unique_id
-    self._handle_name = handle_name + ":0"
-    self._constraint = constraint
-    # After the handle has been created, set up a way to clean it up when
-    # executing eagerly. We'll hold the only reference to the deleter, so that
-    # when this object is garbage collected the deleter will be too. This
-    # means ResourceVariables can be part of reference cycles without those
-    # cycles being uncollectable.
-    if not self._in_graph_mode:
-      self._handle_deleter = EagerResourceDeleter(
-          handle=self._handle, handle_device=self._handle.device)
-    self._cached_shape_as_list = None
-
-
 def copy_to_graph_uninitialized(var):
   """Copies an existing variable to a new graph, with no initializer."""
   # Like ResourceVariable.__deepcopy__, but does not set an initializer on the
diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py
index 1993babfb3c..1dffb9e03c8 100644
--- a/tensorflow/python/saved_model/function_deserialization.py
+++ b/tensorflow/python/saved_model/function_deserialization.py
@@ -38,7 +38,7 @@ from tensorflow.python.util import tf_inspect


 def _is_tensor(t):
-  return isinstance(t, (ops.Tensor, resource_variable_ops.ResourceVariable))
+  return isinstance(t, (ops.Tensor, resource_variable_ops.BaseResourceVariable))


 def _call_concrete_function(function, inputs):
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 1fd84b8293c..8e6ce221825 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -604,9 +604,9 @@ class Optimizer(
         # We colocate all ops created in _apply_dense or _apply_sparse
         # on the same device as the variable.
         # TODO(apassos): figure out how to get the variable name here.
-        if context.executing_eagerly() or isinstance(
-            var,
-            resource_variable_ops.ResourceVariable) and not var._in_graph_mode:  # pylint: disable=protected-access
+        if (context.executing_eagerly() or
+            isinstance(var, resource_variable_ops.BaseResourceVariable)
+            and not var._in_graph_mode):  # pylint: disable=protected-access
           scope_name = ""
         else:
           scope_name = var.op.name
@@ -617,7 +617,8 @@ class Optimizer(
       else:
         with ops.control_dependencies([self._finish(update_ops, "update")]):
           with ops.colocate_with(global_step):
-            if isinstance(global_step, resource_variable_ops.ResourceVariable):
+            if isinstance(
+                global_step, resource_variable_ops.BaseResourceVariable):
               # TODO(apassos): the implicit read in assign_add is slow; consider
               # making it less so.
               apply_updates = resource_variable_ops.assign_add_variable_op(
diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py
index 06a4b7ba4a9..81d3d4d0031 100644
--- a/tensorflow/python/training/saving/saveable_object_util.py
+++ b/tensorflow/python/training/saving/saveable_object_util.py
@@ -179,7 +179,7 @@ def saveable_objects_for_op(op, name):
     # pylint: enable=protected-access
   else:
     # A variable or tensor.
-    if isinstance(op, resource_variable_ops.ResourceVariable):
+    if isinstance(op, resource_variable_ops.BaseResourceVariable):
       # pylint: disable=protected-access
       if op._in_graph_mode:
         variable = op._graph_element
@@ -233,7 +233,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
   # pylint: disable=protected-access
   for var in op_list:
     resource_or_ref_variable = (
-        isinstance(var, resource_variable_ops.ResourceVariable) or
+        isinstance(var, resource_variable_ops.BaseResourceVariable) or
         isinstance(var, variables.RefVariable))

     if isinstance(var, saveable_object.SaveableObject):
@@ -263,7 +263,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
       # indicating whether they were created in a graph building context. We
      # also get Tensors when graph building, which do not have this property.
      if not getattr(var, "_in_graph_mode", True):
-        if not isinstance(var, resource_variable_ops.ResourceVariable):
+        if not isinstance(var, resource_variable_ops.BaseResourceVariable):
          raise ValueError(
              "Can only save/restore ResourceVariables when eager execution "
              "is enabled, type: %s." % type(var))
@@ -277,7 +277,7 @@ def op_list_to_dict(op_list, convert_variable_to_tensor=True):
              (var._shared_name,))
      else:
        if convert_variable_to_tensor:
-          if isinstance(var, resource_variable_ops.ResourceVariable):
+          if isinstance(var, resource_variable_ops.BaseResourceVariable):
            var = var._graph_element  # pylint: disable=protected-access
          else:
            var = ops.internal_convert_to_tensor(var, as_ref=True)
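The recurring change across these hunks is that isinstance checks now target resource_variable_ops.BaseResourceVariable rather than ResourceVariable, so variable-like classes that no longer derive from ResourceVariable (UninitializedVariable, _UnreadVariable, SharedVariable) are still accepted by the save/restore, distribution, and optimizer code paths. The sketch below is a minimal stand-alone illustration of that pattern; it is not TensorFlow code, and the names BaseVariableLike, ConcreteVariable, UninitializedVariableLike, and is_variable_like are hypothetical stand-ins.

# Minimal stand-in sketch (not TensorFlow code): all class and function names
# here are hypothetical and only illustrate why the checks in this patch
# target the common base class rather than the concrete variable class.


class BaseVariableLike(object):
  """Stand-in for BaseResourceVariable: the shared base class."""

  def __init__(self, name):
    self.name = name


class ConcreteVariable(BaseVariableLike):
  """Stand-in for ResourceVariable: a fully initialized variable."""


class UninitializedVariableLike(BaseVariableLike):
  """Stand-in for UninitializedVariable: a handle with no initializer."""


def is_variable_like(obj):
  # Checking against the base class accepts every variable flavor; checking
  # against ConcreteVariable alone would reject UninitializedVariableLike.
  return isinstance(obj, BaseVariableLike)


assert is_variable_like(ConcreteVariable("v"))
assert is_variable_like(UninitializedVariableLike("u"))
assert not isinstance(UninitializedVariableLike("u"), ConcreteVariable)

Seen through this lens, each hunk above simply widens one isinstance check so that the newer variable flavors keep flowing through the existing code paths.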