Do not cache reads for ResourceVariables.

Change: 147487962
This commit is contained in:
Alexandre Passos 2017-02-14 10:16:37 -08:00 committed by TensorFlower Gardener
parent 37fea69023
commit 3e8728dbbc

View File

@ -196,18 +196,14 @@ class ResourceVariable(object):
self._initialize_op = gen_resource_variable_ops.assign_variable_op( self._initialize_op = gen_resource_variable_ops.assign_variable_op(
self._handle, self._initial_value, name=n) self._handle, self._initial_value, name=n)
with ops.name_scope("Read"), ops.colocate_with(self._handle): with ops.name_scope("Read"), ops.colocate_with(self._handle):
self._value = gen_resource_variable_ops.read_variable_op( value = gen_resource_variable_ops.read_variable_op(
self._handle, dtype=self._dtype) self._handle, dtype=self._dtype)
self._graph_element = value
if caching_device is not None: if caching_device is not None:
with ops.device(caching_device): with ops.device(caching_device):
self._cached_value = array_ops.identity(self._value) self._cached_value = array_ops.identity(value)
else: else:
with ops.colocate_with(self._handle.op): self._cached_value = None
self._cached_value = array_ops.identity(self._value)
# TODO(apassos) this is terrible monkey-patching required to make
# initialize_all_variables work. Replace self._value with an explicit
# class instead of monkey-patching.
self._value.initializer = self._initialize_op
ops.add_to_collections(collections, self) ops.add_to_collections(collections, self)
ops.add_to_collections([ops.GraphKeys.RESOURCES], self) ops.add_to_collections([ops.GraphKeys.RESOURCES], self)
@ -225,10 +221,12 @@ class ResourceVariable(object):
self._initialize_op = g.as_graph_element( self._initialize_op = g.as_graph_element(
ops.prepend_name_scope(variable_def.initializer_name, ops.prepend_name_scope(variable_def.initializer_name,
import_scope=import_scope)) import_scope=import_scope))
self._cached_value = g.as_graph_element( if variable_def.snapshot_name:
ops.prepend_name_scope(variable_def.snapshot_name, self._cached_value = g.as_graph_element(
import_scope=import_scope)) ops.prepend_name_scope(variable_def.snapshot_name,
self._value = self._cached_value import_scope=import_scope))
else:
self._cached_value = None
if variable_def.HasField("save_slice_info_def"): if variable_def.HasField("save_slice_info_def"):
self._save_slice_info = variables.Variable.SaveSliceInfo( self._save_slice_info = variables.Variable.SaveSliceInfo(
save_slice_info_def=variable_def.save_slice_info_def) save_slice_info_def=variable_def.save_slice_info_def)
@ -254,7 +252,7 @@ class ResourceVariable(object):
def get_shape(self): def get_shape(self):
"""The shape of this variable.""" """The shape of this variable."""
return self._value.get_shape() return tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
@property @property
def create(self): def create(self):
@ -268,11 +266,14 @@ class ResourceVariable(object):
def value(self): def value(self):
"""A cached operation which reads the value of this variable.""" """A cached operation which reads the value of this variable."""
return self._cached_value if self._cached_value is not None:
return self._cached_value
return gen_resource_variable_ops.read_variable_op(
self._handle, dtype=self._dtype)
def _as_graph_element(self): def _as_graph_element(self):
"""Conversion function for Graph.as_graph_element().""" """Conversion function for Graph.as_graph_element()."""
return self._value return self._graph_element
@property @property
def initializer(self): def initializer(self):
@ -286,7 +287,7 @@ class ResourceVariable(object):
def eval(self, session=None): def eval(self, session=None):
"""Evaluates and returns the value of this variable.""" """Evaluates and returns the value of this variable."""
return self._value.eval(session=session) return self._graph_element.eval(session=session)
def _set_save_slice_info(self, save_slice_info): def _set_save_slice_info(self, save_slice_info):
"""Sets the slice info for this `ResourceVariable`. """Sets the slice info for this `ResourceVariable`.
@ -339,8 +340,9 @@ class ResourceVariable(object):
self.handle.name, export_scope) self.handle.name, export_scope)
var_def.initializer_name = ops.strip_name_scope( var_def.initializer_name = ops.strip_name_scope(
self.initializer.name, export_scope) self.initializer.name, export_scope)
var_def.snapshot_name = ops.strip_name_scope( if self._cached_value is not None:
self.value().name, export_scope) var_def.snapshot_name = ops.strip_name_scope(
self._cached_value.name, export_scope)
var_def.is_resource = True var_def.is_resource = True
if self._save_slice_info: if self._save_slice_info:
var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto( var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto(
@ -420,9 +422,7 @@ def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
if dtype is not None and dtype != var.value().dtype: if dtype is not None and dtype != var.value().dtype:
print("trying to switch the dtype to ", dtype, " from ", var.value().dtype) print("trying to switch the dtype to ", dtype, " from ", var.value().dtype)
return NotImplemented return NotImplemented
if as_ref: return var.value()
return var._value
return var._cached_value
# pylint: enable=unused-argument,protected-access # pylint: enable=unused-argument,protected-access
# Register a conversion function which reads the value of the variable, # Register a conversion function which reads the value of the variable,