Record variable accesses for FuncGraph in more places

This fixes tape recording in nested functions, but may cause a performance issue. We can't use tape.variables_accessed in ConcreteFunction._call_flat since we now need to filter out non-trainable variables.

PiperOrigin-RevId: 241380881
This commit is contained in:
Allen Lavoie 2019-04-01 12:50:39 -07:00 committed by TensorFlower Gardener
parent 27edfec48e
commit a89be16e7a
2 changed files with 14 additions and 12 deletions

View File

@ -575,7 +575,8 @@ class ConcreteFunction(object):
ctx = context.context()
executing_eagerly = ctx.executing_eagerly()
tape.variables_accessed(self._func_graph.variables)
for v in self._func_graph.variables:
resource_variable_ops.variable_accessed(v)
tensor_inputs = []
variables_used = set([])
@ -585,8 +586,7 @@ class ConcreteFunction(object):
# pass its handle only once.
if arg.handle in variables_used:
continue
if arg.trainable:
tape.variable_accessed(arg)
resource_variable_ops.variable_accessed(arg)
tensor_inputs.append(arg.handle)
variables_used.add(arg.handle)
elif isinstance(arg, ops.Tensor):

View File

@ -287,6 +287,14 @@ def _maybe_set_handle_data(dtype, handle, tensor):
shape_and_type=handle_data.shape_and_type[1:]))
def variable_accessed(variable):
"""Records that `variable` was accessed for the tape and FuncGraph."""
if hasattr(ops.get_default_graph(), "watch_variable"):
ops.get_default_graph().watch_variable(variable)
if variable.trainable:
tape.variable_accessed(variable)
class ResourceVariable(variables.VariableV1):
"""Variable based on resource handles.
@ -852,11 +860,7 @@ class ResourceVariable(variables.VariableV1):
T=self.dtype)
def _read_variable_op(self):
if hasattr(ops.get_default_graph(), "watch_variable"):
ops.get_default_graph().watch_variable(self)
if self.trainable:
tape.variable_accessed(self)
variable_accessed(self)
result = gen_resource_variable_ops.read_variable_op(self._handle,
self._dtype)
_maybe_set_handle_data(self._dtype, self._handle, result)
@ -888,8 +892,7 @@ class ResourceVariable(variables.VariableV1):
def sparse_read(self, indices, name=None):
"""Reads the value of this variable sparsely, using `gather`."""
with ops.name_scope("Gather" if name is None else name) as name:
if self.trainable:
tape.variable_accessed(self)
variable_accessed(self)
value = gen_resource_variable_ops.resource_gather(
self._handle, indices, dtype=self._dtype, name=name)
@ -1026,8 +1029,7 @@ class ResourceVariable(variables.VariableV1):
return assign_add_op
def _lazy_read(self, op):
if self.trainable:
tape.variable_accessed(self)
variable_accessed(self)
return _UnreadVariable(
handle=self._handle, dtype=self.dtype, shape=self._shape,
in_graph_mode=self._in_graph_mode,