From a89be16e7a60e75147c9add575c8191a2ceb9a48 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Mon, 1 Apr 2019 12:50:39 -0700 Subject: [PATCH] Record variable accesses for FuncGraph in more places This fixes tape recording in nested functions, but may cause a performance issue. We can't use tape.variables_accessed in ConcreteFunction._call_flat since we now need to filter out non-trainable variables. PiperOrigin-RevId: 241380881 --- tensorflow/python/eager/function.py | 6 +++--- .../python/ops/resource_variable_ops.py | 20 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 91e00c42b6e..060c87f1e70 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -575,7 +575,8 @@ class ConcreteFunction(object): ctx = context.context() executing_eagerly = ctx.executing_eagerly() - tape.variables_accessed(self._func_graph.variables) + for v in self._func_graph.variables: + resource_variable_ops.variable_accessed(v) tensor_inputs = [] variables_used = set([]) @@ -585,8 +586,7 @@ class ConcreteFunction(object): # pass its handle only once. if arg.handle in variables_used: continue - if arg.trainable: - tape.variable_accessed(arg) + resource_variable_ops.variable_accessed(arg) tensor_inputs.append(arg.handle) variables_used.add(arg.handle) elif isinstance(arg, ops.Tensor): diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 60cc65b14b6..69f4c8af5cf 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -287,6 +287,14 @@ def _maybe_set_handle_data(dtype, handle, tensor): shape_and_type=handle_data.shape_and_type[1:])) +def variable_accessed(variable): + """Records that `variable` was accessed for the tape and FuncGraph.""" + if hasattr(ops.get_default_graph(), "watch_variable"): + ops.get_default_graph().watch_variable(variable) + if variable.trainable: + tape.variable_accessed(variable) + + class ResourceVariable(variables.VariableV1): """Variable based on resource handles. @@ -852,11 +860,7 @@ class ResourceVariable(variables.VariableV1): T=self.dtype) def _read_variable_op(self): - if hasattr(ops.get_default_graph(), "watch_variable"): - ops.get_default_graph().watch_variable(self) - - if self.trainable: - tape.variable_accessed(self) + variable_accessed(self) result = gen_resource_variable_ops.read_variable_op(self._handle, self._dtype) _maybe_set_handle_data(self._dtype, self._handle, result) @@ -888,8 +892,7 @@ class ResourceVariable(variables.VariableV1): def sparse_read(self, indices, name=None): """Reads the value of this variable sparsely, using `gather`.""" with ops.name_scope("Gather" if name is None else name) as name: - if self.trainable: - tape.variable_accessed(self) + variable_accessed(self) value = gen_resource_variable_ops.resource_gather( self._handle, indices, dtype=self._dtype, name=name) @@ -1026,8 +1029,7 @@ class ResourceVariable(variables.VariableV1): return assign_add_op def _lazy_read(self, op): - if self.trainable: - tape.variable_accessed(self) + variable_accessed(self) return _UnreadVariable( handle=self._handle, dtype=self.dtype, shape=self._shape, in_graph_mode=self._in_graph_mode,