Change API contract for the private add_function_callback and allow graph edits before finalizing a func_graph.

Currently, function_callbacks are invoked after the graphs are finalized, primarily used to prevent dereferencing. While the API seemingly allows graph edits, they are futile because the graph has already been committed to the C++ runtime.

With this change, we change the semantics of the API slightly and explicitly allow graph edits.

PiperOrigin-RevId: 355877980
Change-Id: If43e37d2fb1032085048b7810aa4e387714e2f88
This commit is contained in:
Stephan Lee 2021-02-05 10:26:42 -08:00 committed by TensorFlower Gardener
parent adcdafc984
commit a5d38592ce
4 changed files with 60 additions and 13 deletions

View File

@ -125,15 +125,21 @@ class _DumpingCallback(object):
self._placeholder_to_debug_tensor = dict()
self._writer = None
def function_callback(self, function):
def function_callback(self, function, name, graph, inputs, outputs):
"""A callback to be called on creation of Functions.
Used to establish a join between function name and graph (context) ID.
Args:
function: The just-created Function.
name: Name of the function.
graph: FuncGraph, the graph containing the operations in the function.
inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs from the function
"""
graph_id = self._get_context_id(function.graph)
del name, inputs, outputs
graph_id = self._get_context_id(graph)
with self._context_lock:
# NOTE(cais): We currently store the function (_EagerDefinedFunction)
# as keys of this dict, because weakrefs to them sometimes become

View File

@ -511,6 +511,7 @@ cuda_py_test(
"//tensorflow/python:init_ops",
"//tensorflow/python:layers",
"//tensorflow/python:list_ops",
"//tensorflow/python:logging_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:resource_variable_ops",

View File

@ -315,11 +315,20 @@ def add_function_callback(function_callback):
The callback function has the signature:
`def function_callback(function):`
`def function_callback(function, name, graph, inputs, outputs):`
wherein `function` is the just-created _EagerDefinedFunction.
The callback is invoked immediately after a new `_EagerDefinedFunction`
is created. The return value(s) of the callback function (if any) is ignored.
where:
- `function`: _EagerDefinedFunction being created before finalizing the graph.
Do not modify the function directly but instead modify the graph.
- `name`: name of the function.
- `graph`: Graph of the function.
- `inputs`: `tuple` of tensors used as inputs to the function.
- `outputs`: `tuple` of tensors used as outputs from the function.
The callback is at the top of the `_EagerDefinedFunction` construction, giving
callback an opportunity to make the last edits to the graph. Do not make
changes to `graph, inputs`, and `outputs` manually, but, instead, set the
`graph` as the default then define ops.
Repeated registration of the same callback function is idempotent.
After a callback is added, it can be removed with the
@ -427,9 +436,12 @@ class _EagerDefinedFunction(object):
name: str, the name for the created function.
graph: Graph, the graph containing the operations in the function
inputs: the tensors in the graph to be used as inputs to the function
outputs: the tensors in the graph which will be outputs to the function
outputs: the tensors in the graph which will be outputs from the function
attrs: dict mapping names of attributes to their AttrValue values
"""
for function_callback in _function_callbacks:
function_callback(self, name, graph, tuple(inputs), tuple(outputs))
input_ops = set(arg.op for arg in inputs)
operations = [op for op in graph.get_operations() if op not in input_ops]
@ -494,9 +506,6 @@ class _EagerDefinedFunction(object):
self.graph = graph
self._stateful_ops = tuple(op for op in operations if op._is_stateful) # pylint: disable=protected-access
for function_callback in _function_callbacks:
function_callback(self)
def add_to_graph(self, g=None):
"""Add the function to the current context or a graph, if supplied.

View File

@ -22,6 +22,7 @@ import copy
import functools
import itertools
import multiprocessing.pool
import os
import sys
import time
import weakref
@ -70,6 +71,7 @@ from tensorflow.python.ops import gen_sendrecv_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
@ -3448,7 +3450,8 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
def testAddFunctionCallback(self):
functions = []
def function_callback(f):
def function_callback(f, name, graph, inputs, outputs):
del name, graph, inputs, outputs
functions.append(f)
@def_function.function
@ -3471,13 +3474,41 @@ class FunctionTest(test.TestCase, parameterized.TestCase):
finally:
function.clear_function_callbacks()
def testFunctionCallbackAddOps(self):
file_name = os.path.join(self.get_temp_dir(), 'test')
def function_callback(f, name, graph, inputs, outputs):
del f, name, inputs
with graph.as_default():
printer = logging_ops.print_v2(
'hello',
output_stream='file://' + file_name
)
outputs[0].op._add_control_input(printer)
@def_function.function
def plus_one(x):
return x + 1
self.addCleanup(function.clear_function_callbacks)
function.add_function_callback(function_callback)
x_float32 = numpy.array(3.0, dtype=numpy.float32)
self.assertAllClose(plus_one(x_float32), 4.0)
with open(file_name, 'r') as f:
self.assertEqual(f.read().strip(), 'hello')
def testRemoveFunctionCallback(self):
functions_1 = []
def function_callback_1(f):
def function_callback_1(f, name, graph, inputs, outputs):
del name, graph, inputs, outputs
functions_1.append(f)
functions_2 = []
def function_callback_2(f):
def function_callback_2(f, name, graph, inputs, outputs):
del name, graph, inputs, outputs
functions_2.append(f)
@def_function.function