Record variable accesses for FuncGraph in more places
This fixes tape recording in nested functions, but may cause a performance issue. We can't use tape.variables_accessed in ConcreteFunction._call_flat since we now need to filter out non-trainable variables. PiperOrigin-RevId: 241380881
This commit is contained in:
parent
27edfec48e
commit
a89be16e7a
@ -575,7 +575,8 @@ class ConcreteFunction(object):
|
||||
ctx = context.context()
|
||||
executing_eagerly = ctx.executing_eagerly()
|
||||
|
||||
tape.variables_accessed(self._func_graph.variables)
|
||||
for v in self._func_graph.variables:
|
||||
resource_variable_ops.variable_accessed(v)
|
||||
|
||||
tensor_inputs = []
|
||||
variables_used = set([])
|
||||
@ -585,8 +586,7 @@ class ConcreteFunction(object):
|
||||
# pass its handle only once.
|
||||
if arg.handle in variables_used:
|
||||
continue
|
||||
if arg.trainable:
|
||||
tape.variable_accessed(arg)
|
||||
resource_variable_ops.variable_accessed(arg)
|
||||
tensor_inputs.append(arg.handle)
|
||||
variables_used.add(arg.handle)
|
||||
elif isinstance(arg, ops.Tensor):
|
||||
|
@ -287,6 +287,14 @@ def _maybe_set_handle_data(dtype, handle, tensor):
|
||||
shape_and_type=handle_data.shape_and_type[1:]))
|
||||
|
||||
|
||||
def variable_accessed(variable):
|
||||
"""Records that `variable` was accessed for the tape and FuncGraph."""
|
||||
if hasattr(ops.get_default_graph(), "watch_variable"):
|
||||
ops.get_default_graph().watch_variable(variable)
|
||||
if variable.trainable:
|
||||
tape.variable_accessed(variable)
|
||||
|
||||
|
||||
class ResourceVariable(variables.VariableV1):
|
||||
"""Variable based on resource handles.
|
||||
|
||||
@ -852,11 +860,7 @@ class ResourceVariable(variables.VariableV1):
|
||||
T=self.dtype)
|
||||
|
||||
def _read_variable_op(self):
|
||||
if hasattr(ops.get_default_graph(), "watch_variable"):
|
||||
ops.get_default_graph().watch_variable(self)
|
||||
|
||||
if self.trainable:
|
||||
tape.variable_accessed(self)
|
||||
variable_accessed(self)
|
||||
result = gen_resource_variable_ops.read_variable_op(self._handle,
|
||||
self._dtype)
|
||||
_maybe_set_handle_data(self._dtype, self._handle, result)
|
||||
@ -888,8 +892,7 @@ class ResourceVariable(variables.VariableV1):
|
||||
def sparse_read(self, indices, name=None):
|
||||
"""Reads the value of this variable sparsely, using `gather`."""
|
||||
with ops.name_scope("Gather" if name is None else name) as name:
|
||||
if self.trainable:
|
||||
tape.variable_accessed(self)
|
||||
variable_accessed(self)
|
||||
value = gen_resource_variable_ops.resource_gather(
|
||||
self._handle, indices, dtype=self._dtype, name=name)
|
||||
|
||||
@ -1026,8 +1029,7 @@ class ResourceVariable(variables.VariableV1):
|
||||
return assign_add_op
|
||||
|
||||
def _lazy_read(self, op):
|
||||
if self.trainable:
|
||||
tape.variable_accessed(self)
|
||||
variable_accessed(self)
|
||||
return _UnreadVariable(
|
||||
handle=self._handle, dtype=self.dtype, shape=self._shape,
|
||||
in_graph_mode=self._in_graph_mode,
|
||||
|
Loading…
x
Reference in New Issue
Block a user