From a5d38592ce513d4a77d27397455ebeb02bc67549 Mon Sep 17 00:00:00 2001 From: Stephan Lee Date: Fri, 5 Feb 2021 10:26:42 -0800 Subject: [PATCH] Change API contract for the private add_function_callback and allow graph edits before finalizing a func_graph. Currently, function_callbacks were invoked after the graphs were finalized, primarily used to prevent dereferencing. While the API seemingly allow graph edits, they are futile as the graph is already committed to the C++. With this change, we change the semantics of the API a little bit and explicitly allows graph edits. PiperOrigin-RevId: 355877980 Change-Id: If43e37d2fb1032085048b7810aa4e387714e2f88 --- .../python/debug/lib/dumping_callback.py | 10 ++++- tensorflow/python/eager/BUILD | 1 + tensorflow/python/eager/function.py | 25 +++++++++---- tensorflow/python/eager/function_test.py | 37 +++++++++++++++++-- 4 files changed, 60 insertions(+), 13 deletions(-) diff --git a/tensorflow/python/debug/lib/dumping_callback.py b/tensorflow/python/debug/lib/dumping_callback.py index 56de65d2339..4e23c0d2539 100644 --- a/tensorflow/python/debug/lib/dumping_callback.py +++ b/tensorflow/python/debug/lib/dumping_callback.py @@ -125,15 +125,21 @@ class _DumpingCallback(object): self._placeholder_to_debug_tensor = dict() self._writer = None - def function_callback(self, function): + def function_callback(self, function, name, graph, inputs, outputs): """A callback to be called on creation of Functions. Used to establish a join between function name and graph (context) ID. Args: function: The just-created Function. + name: Name of the function. + graph: FuncGraph, the graph containing the operations in the function. + inputs: the tensors in the graph to be used as inputs to the function + outputs: the tensors in the graph which will be outputs from the function """ - graph_id = self._get_context_id(function.graph) + del name, inputs, outputs + + graph_id = self._get_context_id(graph) with self._context_lock: # NOTE(cais): We currently store the function (_EagerDefinedFunction) # as keys of this dict, because weakrefs to them sometimes become diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index f765018e426..66b0771f3d1 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -511,6 +511,7 @@ cuda_py_test( "//tensorflow/python:init_ops", "//tensorflow/python:layers", "//tensorflow/python:list_ops", + "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", "//tensorflow/python:random_seed", "//tensorflow/python:resource_variable_ops", diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 2daccff8a89..0b345210d4f 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -315,11 +315,20 @@ def add_function_callback(function_callback): The callback function has the signature: - `def function_callback(function):` + `def function_callback(function, name, graph, inputs, outputs):` - wherein `function` is the just-created _EagerDefinedFunction. - The callback is invoked immediately after a new `_EagerDefinedFunction` - is created. The return value(s) of the callback function (if any) is ignored. + where: + - `function`: _EagerDefinedFunction being created before finalizing the graph. + Do not modify the function directly but instead modify the graph. + - `name`: name of the function. + - `graph`: Graph of the function. + - `inputs`: `tuple` of tensors used as inputs to the function. + - `outputs`: `tuple` of tensors used as outputs from the function. + + The callback is at the top of the `_EagerDefinedFunction` construction, giving + callback an opportunity to make the last edits to the graph. Do not make + changes to `graph, inputs`, and `outputs` manually, but, instead, set the + `graph` as the default then define ops. Repeated registration of the same callback function is idempotent. After a callback is added, it can be removed with the @@ -427,9 +436,12 @@ class _EagerDefinedFunction(object): name: str, the name for the created function. graph: Graph, the graph containing the operations in the function inputs: the tensors in the graph to be used as inputs to the function - outputs: the tensors in the graph which will be outputs to the function + outputs: the tensors in the graph which will be outputs from the function attrs: dict mapping names of attributes to their AttrValue values """ + for function_callback in _function_callbacks: + function_callback(self, name, graph, tuple(inputs), tuple(outputs)) + input_ops = set(arg.op for arg in inputs) operations = [op for op in graph.get_operations() if op not in input_ops] @@ -494,9 +506,6 @@ class _EagerDefinedFunction(object): self.graph = graph self._stateful_ops = tuple(op for op in operations if op._is_stateful) # pylint: disable=protected-access - for function_callback in _function_callbacks: - function_callback(self) - def add_to_graph(self, g=None): """Add the function to the current context or a graph, if supplied. diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index acb4464a6a3..5c8ee1459f2 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -22,6 +22,7 @@ import copy import functools import itertools import multiprocessing.pool +import os import sys import time import weakref @@ -70,6 +71,7 @@ from tensorflow.python.ops import gen_sendrecv_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import list_ops +from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops @@ -3448,7 +3450,8 @@ class FunctionTest(test.TestCase, parameterized.TestCase): def testAddFunctionCallback(self): functions = [] - def function_callback(f): + def function_callback(f, name, graph, inputs, outputs): + del name, graph, inputs, outputs functions.append(f) @def_function.function @@ -3471,13 +3474,41 @@ class FunctionTest(test.TestCase, parameterized.TestCase): finally: function.clear_function_callbacks() + def testFunctionCallbackAddOps(self): + file_name = os.path.join(self.get_temp_dir(), 'test') + + def function_callback(f, name, graph, inputs, outputs): + del f, name, inputs + + with graph.as_default(): + printer = logging_ops.print_v2( + 'hello', + output_stream='file://' + file_name + ) + outputs[0].op._add_control_input(printer) + + @def_function.function + def plus_one(x): + return x + 1 + + self.addCleanup(function.clear_function_callbacks) + function.add_function_callback(function_callback) + x_float32 = numpy.array(3.0, dtype=numpy.float32) + + self.assertAllClose(plus_one(x_float32), 4.0) + + with open(file_name, 'r') as f: + self.assertEqual(f.read().strip(), 'hello') + def testRemoveFunctionCallback(self): functions_1 = [] - def function_callback_1(f): + def function_callback_1(f, name, graph, inputs, outputs): + del name, graph, inputs, outputs functions_1.append(f) functions_2 = [] - def function_callback_2(f): + def function_callback_2(f, name, graph, inputs, outputs): + del name, graph, inputs, outputs functions_2.append(f) @def_function.function