Add a simple internal utility for collecting post-optimization graphs
I can never remember how to call export_run_metadata and parse its result otherwise. PiperOrigin-RevId: 279831870 Change-Id: Id8fd2e432d3b680007de6d6e942152a8fe3ccdf7
This commit is contained in:
parent
7db012db68
commit
4a0ba5e5b1
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import random
|
import random
|
||||||
import threading
|
import threading
|
||||||
@ -2007,6 +2008,42 @@ def export_run_metadata():
|
|||||||
return context().export_run_metadata()
|
return context().export_run_metadata()
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def collect_optimized_graphs():
|
||||||
|
"""Collects a flat list of post-optimization graphs.
|
||||||
|
|
||||||
|
The collected graphs include device placements, which can be useful for
|
||||||
|
testing.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
```
|
||||||
|
@def_function.function
|
||||||
|
def f(x):
|
||||||
|
return x + constant_op.constant(1.)
|
||||||
|
|
||||||
|
with context.collect_optimized_graphs() as graphs:
|
||||||
|
with ops.device("CPU:0"):
|
||||||
|
f(constant_op.constant(1.))
|
||||||
|
|
||||||
|
graph, = graphs # `graph` contains a single GraphDef for inspection
|
||||||
|
```
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
A list of GraphDefs, populated when the context manager exits.
|
||||||
|
"""
|
||||||
|
ctx = context()
|
||||||
|
ctx.enable_graph_collection()
|
||||||
|
try:
|
||||||
|
graphs = []
|
||||||
|
yield graphs
|
||||||
|
metadata = ctx.export_run_metadata()
|
||||||
|
finally:
|
||||||
|
ctx.disable_graph_collection()
|
||||||
|
for graph in metadata.function_graphs:
|
||||||
|
graphs.append(graph.post_optimization_graph)
|
||||||
|
|
||||||
|
|
||||||
def get_server_def():
|
def get_server_def():
|
||||||
return context().get_server_def()
|
return context().get_server_def()
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ import weakref
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -71,6 +72,20 @@ class ContextTest(test.TestCase):
|
|||||||
del tensor2
|
del tensor2
|
||||||
self.assertIs(weak_c(), None)
|
self.assertIs(weak_c(), None)
|
||||||
|
|
||||||
|
def testSimpleGraphCollection(self):
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def f(x):
|
||||||
|
return x + constant_op.constant(1.)
|
||||||
|
|
||||||
|
with context.collect_optimized_graphs() as graphs:
|
||||||
|
with ops.device('CPU:0'):
|
||||||
|
f(constant_op.constant(1.))
|
||||||
|
|
||||||
|
self.assertLen(graphs, 1)
|
||||||
|
graph, = graphs
|
||||||
|
self.assertIn('CPU:0', graph.node[0].device)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user