diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 8de73bc35d1..93dd0f55e75 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import collections +import contextlib import copy import random import threading @@ -2007,6 +2008,42 @@ def export_run_metadata(): return context().export_run_metadata() +@contextlib.contextmanager +def collect_optimized_graphs(): + """Collects a flat list of post-optimization graphs. + + The collected graphs include device placements, which can be useful for + testing. + + Usage: + + ``` + @def_function.function + def f(x): + return x + constant_op.constant(1.) + + with context.collect_optimized_graphs() as graphs: + with ops.device("CPU:0"): + f(constant_op.constant(1.)) + + graph, = graphs # `graph` contains a single GraphDef for inspection + ``` + + Yields: + A list of GraphDefs, populated when the context manager exits. + """ + ctx = context() + ctx.enable_graph_collection() + try: + graphs = [] + yield graphs + metadata = ctx.export_run_metadata() + finally: + ctx.disable_graph_collection() + for graph in metadata.function_graphs: + graphs.append(graph.post_optimization_graph) + + def get_server_def(): return context().get_server_def() diff --git a/tensorflow/python/eager/context_test.py b/tensorflow/python/eager/context_test.py index 3b1a3c27622..51738fd8de9 100644 --- a/tensorflow/python/eager/context_test.py +++ b/tensorflow/python/eager/context_test.py @@ -22,6 +22,7 @@ import weakref import numpy as np from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -71,6 +72,20 @@ class ContextTest(test.TestCase): del tensor2 self.assertIs(weak_c(), None) + def testSimpleGraphCollection(self): + + @def_function.function + def f(x): + return x + constant_op.constant(1.) + + with context.collect_optimized_graphs() as graphs: + with ops.device('CPU:0'): + f(constant_op.constant(1.)) + + self.assertLen(graphs, 1) + graph, = graphs + self.assertIn('CPU:0', graph.node[0].device) + if __name__ == '__main__': ops.enable_eager_execution()