Do not cache reads for ResourceVariables.
Change: 147487962
This commit is contained in:
parent
37fea69023
commit
3e8728dbbc
@ -196,18 +196,14 @@ class ResourceVariable(object):
|
|||||||
self._initialize_op = gen_resource_variable_ops.assign_variable_op(
|
self._initialize_op = gen_resource_variable_ops.assign_variable_op(
|
||||||
self._handle, self._initial_value, name=n)
|
self._handle, self._initial_value, name=n)
|
||||||
with ops.name_scope("Read"), ops.colocate_with(self._handle):
|
with ops.name_scope("Read"), ops.colocate_with(self._handle):
|
||||||
self._value = gen_resource_variable_ops.read_variable_op(
|
value = gen_resource_variable_ops.read_variable_op(
|
||||||
self._handle, dtype=self._dtype)
|
self._handle, dtype=self._dtype)
|
||||||
|
self._graph_element = value
|
||||||
if caching_device is not None:
|
if caching_device is not None:
|
||||||
with ops.device(caching_device):
|
with ops.device(caching_device):
|
||||||
self._cached_value = array_ops.identity(self._value)
|
self._cached_value = array_ops.identity(value)
|
||||||
else:
|
else:
|
||||||
with ops.colocate_with(self._handle.op):
|
self._cached_value = None
|
||||||
self._cached_value = array_ops.identity(self._value)
|
|
||||||
# TODO(apassos) this is terrible monkey-patching required to make
|
|
||||||
# initialize_all_variables work. Replace self._value with an explicit
|
|
||||||
# class instead of monkey-patching.
|
|
||||||
self._value.initializer = self._initialize_op
|
|
||||||
ops.add_to_collections(collections, self)
|
ops.add_to_collections(collections, self)
|
||||||
ops.add_to_collections([ops.GraphKeys.RESOURCES], self)
|
ops.add_to_collections([ops.GraphKeys.RESOURCES], self)
|
||||||
|
|
||||||
@ -225,10 +221,12 @@ class ResourceVariable(object):
|
|||||||
self._initialize_op = g.as_graph_element(
|
self._initialize_op = g.as_graph_element(
|
||||||
ops.prepend_name_scope(variable_def.initializer_name,
|
ops.prepend_name_scope(variable_def.initializer_name,
|
||||||
import_scope=import_scope))
|
import_scope=import_scope))
|
||||||
self._cached_value = g.as_graph_element(
|
if variable_def.snapshot_name:
|
||||||
ops.prepend_name_scope(variable_def.snapshot_name,
|
self._cached_value = g.as_graph_element(
|
||||||
import_scope=import_scope))
|
ops.prepend_name_scope(variable_def.snapshot_name,
|
||||||
self._value = self._cached_value
|
import_scope=import_scope))
|
||||||
|
else:
|
||||||
|
self._cached_value = None
|
||||||
if variable_def.HasField("save_slice_info_def"):
|
if variable_def.HasField("save_slice_info_def"):
|
||||||
self._save_slice_info = variables.Variable.SaveSliceInfo(
|
self._save_slice_info = variables.Variable.SaveSliceInfo(
|
||||||
save_slice_info_def=variable_def.save_slice_info_def)
|
save_slice_info_def=variable_def.save_slice_info_def)
|
||||||
@ -254,7 +252,7 @@ class ResourceVariable(object):
|
|||||||
|
|
||||||
def get_shape(self):
|
def get_shape(self):
|
||||||
"""The shape of this variable."""
|
"""The shape of this variable."""
|
||||||
return self._value.get_shape()
|
return tensor_shape.TensorShape(self._handle.op.get_attr("shape"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def create(self):
|
def create(self):
|
||||||
@ -268,11 +266,14 @@ class ResourceVariable(object):
|
|||||||
|
|
||||||
def value(self):
|
def value(self):
|
||||||
"""A cached operation which reads the value of this variable."""
|
"""A cached operation which reads the value of this variable."""
|
||||||
return self._cached_value
|
if self._cached_value is not None:
|
||||||
|
return self._cached_value
|
||||||
|
return gen_resource_variable_ops.read_variable_op(
|
||||||
|
self._handle, dtype=self._dtype)
|
||||||
|
|
||||||
def _as_graph_element(self):
|
def _as_graph_element(self):
|
||||||
"""Conversion function for Graph.as_graph_element()."""
|
"""Conversion function for Graph.as_graph_element()."""
|
||||||
return self._value
|
return self._graph_element
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def initializer(self):
|
def initializer(self):
|
||||||
@ -286,7 +287,7 @@ class ResourceVariable(object):
|
|||||||
|
|
||||||
def eval(self, session=None):
|
def eval(self, session=None):
|
||||||
"""Evaluates and returns the value of this variable."""
|
"""Evaluates and returns the value of this variable."""
|
||||||
return self._value.eval(session=session)
|
return self._graph_element.eval(session=session)
|
||||||
|
|
||||||
def _set_save_slice_info(self, save_slice_info):
|
def _set_save_slice_info(self, save_slice_info):
|
||||||
"""Sets the slice info for this `ResourceVariable`.
|
"""Sets the slice info for this `ResourceVariable`.
|
||||||
@ -339,8 +340,9 @@ class ResourceVariable(object):
|
|||||||
self.handle.name, export_scope)
|
self.handle.name, export_scope)
|
||||||
var_def.initializer_name = ops.strip_name_scope(
|
var_def.initializer_name = ops.strip_name_scope(
|
||||||
self.initializer.name, export_scope)
|
self.initializer.name, export_scope)
|
||||||
var_def.snapshot_name = ops.strip_name_scope(
|
if self._cached_value is not None:
|
||||||
self.value().name, export_scope)
|
var_def.snapshot_name = ops.strip_name_scope(
|
||||||
|
self._cached_value.name, export_scope)
|
||||||
var_def.is_resource = True
|
var_def.is_resource = True
|
||||||
if self._save_slice_info:
|
if self._save_slice_info:
|
||||||
var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto(
|
var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto(
|
||||||
@ -420,9 +422,7 @@ def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
|
|||||||
if dtype is not None and dtype != var.value().dtype:
|
if dtype is not None and dtype != var.value().dtype:
|
||||||
print("trying to switch the dtype to ", dtype, " from ", var.value().dtype)
|
print("trying to switch the dtype to ", dtype, " from ", var.value().dtype)
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
if as_ref:
|
return var.value()
|
||||||
return var._value
|
|
||||||
return var._cached_value
|
|
||||||
# pylint: enable=unused-argument,protected-access
|
# pylint: enable=unused-argument,protected-access
|
||||||
|
|
||||||
# Register a conversion function which reads the value of the variable,
|
# Register a conversion function which reads the value of the variable,
|
||||||
|
Loading…
Reference in New Issue
Block a user