Do not cache reads for ResourceVariables.
Change: 147487962
This commit is contained in:
parent
37fea69023
commit
3e8728dbbc
@ -196,18 +196,14 @@ class ResourceVariable(object):
|
||||
self._initialize_op = gen_resource_variable_ops.assign_variable_op(
|
||||
self._handle, self._initial_value, name=n)
|
||||
with ops.name_scope("Read"), ops.colocate_with(self._handle):
|
||||
self._value = gen_resource_variable_ops.read_variable_op(
|
||||
value = gen_resource_variable_ops.read_variable_op(
|
||||
self._handle, dtype=self._dtype)
|
||||
self._graph_element = value
|
||||
if caching_device is not None:
|
||||
with ops.device(caching_device):
|
||||
self._cached_value = array_ops.identity(self._value)
|
||||
self._cached_value = array_ops.identity(value)
|
||||
else:
|
||||
with ops.colocate_with(self._handle.op):
|
||||
self._cached_value = array_ops.identity(self._value)
|
||||
# TODO(apassos) this is terrible monkey-patching required to make
|
||||
# initialize_all_variables work. Replace self._value with an explicit
|
||||
# class instead of monkey-patching.
|
||||
self._value.initializer = self._initialize_op
|
||||
self._cached_value = None
|
||||
ops.add_to_collections(collections, self)
|
||||
ops.add_to_collections([ops.GraphKeys.RESOURCES], self)
|
||||
|
||||
@ -225,10 +221,12 @@ class ResourceVariable(object):
|
||||
self._initialize_op = g.as_graph_element(
|
||||
ops.prepend_name_scope(variable_def.initializer_name,
|
||||
import_scope=import_scope))
|
||||
self._cached_value = g.as_graph_element(
|
||||
ops.prepend_name_scope(variable_def.snapshot_name,
|
||||
import_scope=import_scope))
|
||||
self._value = self._cached_value
|
||||
if variable_def.snapshot_name:
|
||||
self._cached_value = g.as_graph_element(
|
||||
ops.prepend_name_scope(variable_def.snapshot_name,
|
||||
import_scope=import_scope))
|
||||
else:
|
||||
self._cached_value = None
|
||||
if variable_def.HasField("save_slice_info_def"):
|
||||
self._save_slice_info = variables.Variable.SaveSliceInfo(
|
||||
save_slice_info_def=variable_def.save_slice_info_def)
|
||||
@ -254,7 +252,7 @@ class ResourceVariable(object):
|
||||
|
||||
def get_shape(self):
|
||||
"""The shape of this variable."""
|
||||
return self._value.get_shape()
|
||||
return tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
|
||||
|
||||
@property
|
||||
def create(self):
|
||||
@ -268,11 +266,14 @@ class ResourceVariable(object):
|
||||
|
||||
def value(self):
|
||||
"""A cached operation which reads the value of this variable."""
|
||||
return self._cached_value
|
||||
if self._cached_value is not None:
|
||||
return self._cached_value
|
||||
return gen_resource_variable_ops.read_variable_op(
|
||||
self._handle, dtype=self._dtype)
|
||||
|
||||
def _as_graph_element(self):
|
||||
"""Conversion function for Graph.as_graph_element()."""
|
||||
return self._value
|
||||
return self._graph_element
|
||||
|
||||
@property
|
||||
def initializer(self):
|
||||
@ -286,7 +287,7 @@ class ResourceVariable(object):
|
||||
|
||||
def eval(self, session=None):
|
||||
"""Evaluates and returns the value of this variable."""
|
||||
return self._value.eval(session=session)
|
||||
return self._graph_element.eval(session=session)
|
||||
|
||||
def _set_save_slice_info(self, save_slice_info):
|
||||
"""Sets the slice info for this `ResourceVariable`.
|
||||
@ -339,8 +340,9 @@ class ResourceVariable(object):
|
||||
self.handle.name, export_scope)
|
||||
var_def.initializer_name = ops.strip_name_scope(
|
||||
self.initializer.name, export_scope)
|
||||
var_def.snapshot_name = ops.strip_name_scope(
|
||||
self.value().name, export_scope)
|
||||
if self._cached_value is not None:
|
||||
var_def.snapshot_name = ops.strip_name_scope(
|
||||
self._cached_value.name, export_scope)
|
||||
var_def.is_resource = True
|
||||
if self._save_slice_info:
|
||||
var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto(
|
||||
@ -420,9 +422,7 @@ def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
|
||||
if dtype is not None and dtype != var.value().dtype:
|
||||
print("trying to switch the dtype to ", dtype, " from ", var.value().dtype)
|
||||
return NotImplemented
|
||||
if as_ref:
|
||||
return var._value
|
||||
return var._cached_value
|
||||
return var.value()
|
||||
# pylint: enable=unused-argument,protected-access
|
||||
|
||||
# Register a conversion function which reads the value of the variable,
|
||||
|
Loading…
Reference in New Issue
Block a user