Add a simple internal utility for collecting post-optimization graphs

I can never remember how to call export_run_metadata and parse its result otherwise.

PiperOrigin-RevId: 279831870
Change-Id: Id8fd2e432d3b680007de6d6e942152a8fe3ccdf7
This commit is contained in:
Allen Lavoie 2019-11-11 15:35:26 -08:00 committed by TensorFlower Gardener
parent 7db012db68
commit 4a0ba5e5b1
2 changed files with 52 additions and 0 deletions

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
import contextlib
import copy
import random
import threading
@ -2007,6 +2008,42 @@ def export_run_metadata():
return context().export_run_metadata()
@contextlib.contextmanager
def collect_optimized_graphs():
"""Collects a flat list of post-optimization graphs.
The collected graphs include device placements, which can be useful for
testing.
Usage:
```
@def_function.function
def f(x):
return x + constant_op.constant(1.)
with context.collect_optimized_graphs() as graphs:
with ops.device("CPU:0"):
f(constant_op.constant(1.))
graph, = graphs # `graph` contains a single GraphDef for inspection
```
Yields:
A list of GraphDefs, populated when the context manager exits.
"""
ctx = context()
ctx.enable_graph_collection()
try:
graphs = []
yield graphs
metadata = ctx.export_run_metadata()
finally:
ctx.disable_graph_collection()
for graph in metadata.function_graphs:
graphs.append(graph.post_optimization_graph)
def get_server_def():
return context().get_server_def()

View File

@ -22,6 +22,7 @@ import weakref
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
@ -71,6 +72,20 @@ class ContextTest(test.TestCase):
del tensor2
self.assertIs(weak_c(), None)
def testSimpleGraphCollection(self):
@def_function.function
def f(x):
return x + constant_op.constant(1.)
with context.collect_optimized_graphs() as graphs:
with ops.device('CPU:0'):
f(constant_op.constant(1.))
self.assertLen(graphs, 1)
graph, = graphs
self.assertIn('CPU:0', graph.node[0].device)
if __name__ == '__main__':
ops.enable_eager_execution()