Record variable accesses for FuncGraph in more places
This fixes tape recording in nested functions, but may cause a performance issue. We can't use tape.variables_accessed in ConcreteFunction._call_flat since we now need to filter out non-trainable variables. PiperOrigin-RevId: 241380881
This commit is contained in:
parent
27edfec48e
commit
a89be16e7a
@ -575,7 +575,8 @@ class ConcreteFunction(object):
|
|||||||
ctx = context.context()
|
ctx = context.context()
|
||||||
executing_eagerly = ctx.executing_eagerly()
|
executing_eagerly = ctx.executing_eagerly()
|
||||||
|
|
||||||
tape.variables_accessed(self._func_graph.variables)
|
for v in self._func_graph.variables:
|
||||||
|
resource_variable_ops.variable_accessed(v)
|
||||||
|
|
||||||
tensor_inputs = []
|
tensor_inputs = []
|
||||||
variables_used = set([])
|
variables_used = set([])
|
||||||
@ -585,8 +586,7 @@ class ConcreteFunction(object):
|
|||||||
# pass its handle only once.
|
# pass its handle only once.
|
||||||
if arg.handle in variables_used:
|
if arg.handle in variables_used:
|
||||||
continue
|
continue
|
||||||
if arg.trainable:
|
resource_variable_ops.variable_accessed(arg)
|
||||||
tape.variable_accessed(arg)
|
|
||||||
tensor_inputs.append(arg.handle)
|
tensor_inputs.append(arg.handle)
|
||||||
variables_used.add(arg.handle)
|
variables_used.add(arg.handle)
|
||||||
elif isinstance(arg, ops.Tensor):
|
elif isinstance(arg, ops.Tensor):
|
||||||
|
@ -287,6 +287,14 @@ def _maybe_set_handle_data(dtype, handle, tensor):
|
|||||||
shape_and_type=handle_data.shape_and_type[1:]))
|
shape_and_type=handle_data.shape_and_type[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
def variable_accessed(variable):
|
||||||
|
"""Records that `variable` was accessed for the tape and FuncGraph."""
|
||||||
|
if hasattr(ops.get_default_graph(), "watch_variable"):
|
||||||
|
ops.get_default_graph().watch_variable(variable)
|
||||||
|
if variable.trainable:
|
||||||
|
tape.variable_accessed(variable)
|
||||||
|
|
||||||
|
|
||||||
class ResourceVariable(variables.VariableV1):
|
class ResourceVariable(variables.VariableV1):
|
||||||
"""Variable based on resource handles.
|
"""Variable based on resource handles.
|
||||||
|
|
||||||
@ -852,11 +860,7 @@ class ResourceVariable(variables.VariableV1):
|
|||||||
T=self.dtype)
|
T=self.dtype)
|
||||||
|
|
||||||
def _read_variable_op(self):
|
def _read_variable_op(self):
|
||||||
if hasattr(ops.get_default_graph(), "watch_variable"):
|
variable_accessed(self)
|
||||||
ops.get_default_graph().watch_variable(self)
|
|
||||||
|
|
||||||
if self.trainable:
|
|
||||||
tape.variable_accessed(self)
|
|
||||||
result = gen_resource_variable_ops.read_variable_op(self._handle,
|
result = gen_resource_variable_ops.read_variable_op(self._handle,
|
||||||
self._dtype)
|
self._dtype)
|
||||||
_maybe_set_handle_data(self._dtype, self._handle, result)
|
_maybe_set_handle_data(self._dtype, self._handle, result)
|
||||||
@ -888,8 +892,7 @@ class ResourceVariable(variables.VariableV1):
|
|||||||
def sparse_read(self, indices, name=None):
|
def sparse_read(self, indices, name=None):
|
||||||
"""Reads the value of this variable sparsely, using `gather`."""
|
"""Reads the value of this variable sparsely, using `gather`."""
|
||||||
with ops.name_scope("Gather" if name is None else name) as name:
|
with ops.name_scope("Gather" if name is None else name) as name:
|
||||||
if self.trainable:
|
variable_accessed(self)
|
||||||
tape.variable_accessed(self)
|
|
||||||
value = gen_resource_variable_ops.resource_gather(
|
value = gen_resource_variable_ops.resource_gather(
|
||||||
self._handle, indices, dtype=self._dtype, name=name)
|
self._handle, indices, dtype=self._dtype, name=name)
|
||||||
|
|
||||||
@ -1026,8 +1029,7 @@ class ResourceVariable(variables.VariableV1):
|
|||||||
return assign_add_op
|
return assign_add_op
|
||||||
|
|
||||||
def _lazy_read(self, op):
|
def _lazy_read(self, op):
|
||||||
if self.trainable:
|
variable_accessed(self)
|
||||||
tape.variable_accessed(self)
|
|
||||||
return _UnreadVariable(
|
return _UnreadVariable(
|
||||||
handle=self._handle, dtype=self.dtype, shape=self._shape,
|
handle=self._handle, dtype=self.dtype, shape=self._shape,
|
||||||
in_graph_mode=self._in_graph_mode,
|
in_graph_mode=self._in_graph_mode,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user