From 4a0ba5e5b189813b679f7e943587ea5e0234a6d8 Mon Sep 17 00:00:00 2001
From: Allen Lavoie <allenl@google.com>
Date: Mon, 11 Nov 2019 15:35:26 -0800
Subject: [PATCH] Add a simple internal utility for collecting
 post-optimization graphs

I can never remember how to call export_run_metadata and parse its result otherwise.

PiperOrigin-RevId: 279831870
Change-Id: Id8fd2e432d3b680007de6d6e942152a8fe3ccdf7
---
 tensorflow/python/eager/context.py      | 37 +++++++++++++++++++++++++
 tensorflow/python/eager/context_test.py | 15 ++++++++++
 2 files changed, 52 insertions(+)

diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 8de73bc35d1..93dd0f55e75 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import contextlib
 import copy
 import random
 import threading
@@ -2007,6 +2008,42 @@ def export_run_metadata():
   return context().export_run_metadata()
 
 
+@contextlib.contextmanager
+def collect_optimized_graphs():
+  """Collects a flat list of post-optimization graphs.
+
+  The collected graphs include device placements, which can be useful for
+  testing.
+
+  Usage:
+
+  ```
+  @def_function.function
+  def f(x):
+    return x + constant_op.constant(1.)
+
+  with context.collect_optimized_graphs() as graphs:
+    with ops.device("CPU:0"):
+      f(constant_op.constant(1.))
+
+  graph, = graphs  # `graph` contains a single GraphDef for inspection
+  ```
+
+  Yields:
+    A list of GraphDefs, populated when the context manager exits.
+  """
+  ctx = context()
+  ctx.enable_graph_collection()
+  try:
+    graphs = []
+    yield graphs
+    metadata = ctx.export_run_metadata()
+  finally:
+    ctx.disable_graph_collection()
+  for graph in metadata.function_graphs:
+    graphs.append(graph.post_optimization_graph)
+
+
 def get_server_def():
   return context().get_server_def()
 
diff --git a/tensorflow/python/eager/context_test.py b/tensorflow/python/eager/context_test.py
index 3b1a3c27622..51738fd8de9 100644
--- a/tensorflow/python/eager/context_test.py
+++ b/tensorflow/python/eager/context_test.py
@@ -22,6 +22,7 @@ import weakref
 import numpy as np
 
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.platform import test
@@ -71,6 +72,20 @@ class ContextTest(test.TestCase):
     del tensor2
     self.assertIs(weak_c(), None)
 
+  def testSimpleGraphCollection(self):
+
+    @def_function.function
+    def f(x):
+      return x + constant_op.constant(1.)
+
+    with context.collect_optimized_graphs() as graphs:
+      with ops.device('CPU:0'):
+        f(constant_op.constant(1.))
+
+    self.assertLen(graphs, 1)
+    graph, = graphs
+    self.assertIn('CPU:0', graph.node[0].device)
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()