From 4ec6f2b07c08ddab479541cad0c61f169c1f816f Mon Sep 17 00:00:00 2001
From: Alexandre Passos <apassos@google.com>
Date: Mon, 23 Oct 2017 09:59:21 -0700
Subject: [PATCH] Switching contrib.summaries API to be context-manager-centric

PiperOrigin-RevId: 173129793
---
 tensorflow/contrib/summary/summary_ops.py     | 33 +++++++-
 .../contrib/summary/summary_ops_test.py       | 79 ++++++++++---------
 tensorflow/python/eager/context.py            |  6 +-
 3 files changed, 77 insertions(+), 41 deletions(-)

diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index ba3619bfc90..30a9398ee54 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.layers import utils
 from tensorflow.python.ops import summary_op_util
 from tensorflow.python.training import training_util
+from tensorflow.python.util import tf_contextlib
 
 # Name for a collection which is expected to have at most a single boolean
 # Tensor. If this tensor is True the summary ops will record summaries.
@@ -46,22 +47,50 @@ def should_record_summaries():
 
 
 # TODO(apassos) consider how to handle local step here.
+@tf_contextlib.contextmanager
 def record_summaries_every_n_global_steps(n):
   """Sets the should_record_summaries Tensor to true if global_step % n == 0."""
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+  old = collection_ref[:]
   collection_ref[:] = [training_util.get_global_step() % n == 0]
+  yield
+  collection_ref[:] = old
 
 
+@tf_contextlib.contextmanager
 def always_record_summaries():
   """Sets the should_record_summaries Tensor to always true."""
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+  old = collection_ref[:]
   collection_ref[:] = [True]
+  yield
+  collection_ref[:] = old
 
 
+@tf_contextlib.contextmanager
 def never_record_summaries():
   """Sets the should_record_summaries Tensor to always false."""
   collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
+  old = collection_ref[:]
   collection_ref[:] = [False]
+  yield
+  collection_ref[:] = old
+
+
+class SummaryWriter(object):
+
+  def __init__(self, resource):
+    self._resource = resource
+
+  def set_as_default(self):
+    context.context().summary_writer_resource = self._resource
+
+  @tf_contextlib.contextmanager
+  def as_default(self):
+    old = context.context().summary_writer_resource
+    context.context().summary_writer_resource = self._resource
+    yield
+    context.context().summary_writer_resource = old
 
 
 def create_summary_file_writer(logdir,
@@ -77,9 +106,11 @@ def create_summary_file_writer(logdir,
   if filename_suffix is None:
     filename_suffix = constant_op.constant("")
   resource = gen_summary_ops.summary_writer(shared_name=name)
+  # TODO(apassos) ensure the initialization op runs when in graph mode; consider
+  # calling session.run here.
   gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
                                              flush_secs, filename_suffix)
-  context.context().summary_writer_resource = resource
+  return SummaryWriter(resource)
 
 
 def _nothing():
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index 2cd4fce5b39..405a92a7263 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -41,60 +41,65 @@ class TargetTest(test_util.TensorFlowTestCase):
 
   def testShouldRecordSummary(self):
     self.assertFalse(summary_ops.should_record_summaries())
-    summary_ops.always_record_summaries()
-    self.assertTrue(summary_ops.should_record_summaries())
+    with summary_ops.always_record_summaries():
+      self.assertTrue(summary_ops.should_record_summaries())
 
   def testSummaryOps(self):
     training_util.get_or_create_global_step()
     logdir = tempfile.mkdtemp()
-    summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0')
-    summary_ops.always_record_summaries()
-    summary_ops.generic('tensor', 1, '')
-    summary_ops.scalar('scalar', 2.0)
-    summary_ops.histogram('histogram', [1.0])
-    summary_ops.image('image', [[[[1.0]]]])
-    summary_ops.audio('audio', [[1.0]], 1.0, 1)
-    # The working condition of the ops is tested in the C++ test so we just
-    # test here that we're calling them correctly.
-    self.assertTrue(gfile.Exists(logdir))
+    with summary_ops.create_summary_file_writer(
+        logdir, max_queue=0,
+        name='t0').as_default(), summary_ops.always_record_summaries():
+      summary_ops.generic('tensor', 1, '')
+      summary_ops.scalar('scalar', 2.0)
+      summary_ops.histogram('histogram', [1.0])
+      summary_ops.image('image', [[[[1.0]]]])
+      summary_ops.audio('audio', [[1.0]], 1.0, 1)
+      # The working condition of the ops is tested in the C++ test so we just
+      # test here that we're calling them correctly.
+      self.assertTrue(gfile.Exists(logdir))
 
   def testDefunSummarys(self):
     training_util.get_or_create_global_step()
     logdir = tempfile.mkdtemp()
-    summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1')
-    summary_ops.always_record_summaries()
+    with summary_ops.create_summary_file_writer(
+        logdir, max_queue=0,
+        name='t1').as_default(), summary_ops.always_record_summaries():
 
-    @function.defun
-    def write():
-      summary_ops.scalar('scalar', 2.0)
+      @function.defun
+      def write():
+        summary_ops.scalar('scalar', 2.0)
 
-    write()
+      write()
 
-    self.assertTrue(gfile.Exists(logdir))
-    files = gfile.ListDirectory(logdir)
-    self.assertEqual(len(files), 1)
-    records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
-    self.assertEqual(len(records), 2)
-    event = event_pb2.Event()
-    event.ParseFromString(records[1])
-    self.assertEqual(event.summary.value[0].simple_value, 2.0)
+      self.assertTrue(gfile.Exists(logdir))
+      files = gfile.ListDirectory(logdir)
+      self.assertEqual(len(files), 1)
+      records = list(
+          tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
+      self.assertEqual(len(records), 2)
+      event = event_pb2.Event()
+      event.ParseFromString(records[1])
+      self.assertEqual(event.summary.value[0].simple_value, 2.0)
 
   def testSummaryName(self):
     training_util.get_or_create_global_step()
     logdir = tempfile.mkdtemp()
-    summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t2')
-    summary_ops.always_record_summaries()
+    with summary_ops.create_summary_file_writer(
+        logdir, max_queue=0,
+        name='t2').as_default(), summary_ops.always_record_summaries():
 
-    summary_ops.scalar('scalar', 2.0)
+      summary_ops.scalar('scalar', 2.0)
 
-    self.assertTrue(gfile.Exists(logdir))
-    files = gfile.ListDirectory(logdir)
-    self.assertEqual(len(files), 1)
-    records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
-    self.assertEqual(len(records), 2)
-    event = event_pb2.Event()
-    event.ParseFromString(records[1])
-    self.assertEqual(event.summary.value[0].tag, 'scalar')
+      self.assertTrue(gfile.Exists(logdir))
+      files = gfile.ListDirectory(logdir)
+      self.assertEqual(len(files), 1)
+      records = list(
+          tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
+      self.assertEqual(len(records), 2)
+      event = event_pb2.Event()
+      event.ParseFromString(records[1])
+      self.assertEqual(event.summary.value[0].tag, 'scalar')
 
 
 if __name__ == '__main__':
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index c5eedb7c9cf..92f4e15c054 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -58,6 +58,7 @@ class _EagerContext(threading.local):
     self.mode = _default_mode
     self.scope_name = ""
     self.recording_summaries = False
+    self.summary_writer_resource = None
     self.scalar_cache = {}
 
 
@@ -86,7 +87,6 @@ class Context(object):
     self._eager_context = _EagerContext()
     self._context_handle = None
     self._context_devices = None
-    self._summary_writer_resource = None
     self._post_execution_callbacks = []
     self._config = config
     self._seed = None
@@ -213,12 +213,12 @@ class Context(object):
   @property
   def summary_writer_resource(self):
     """Returns summary writer resource."""
-    return self._summary_writer_resource
+    return self._eager_context.summary_writer_resource
 
   @summary_writer_resource.setter
   def summary_writer_resource(self, resource):
     """Sets summary writer resource."""
-    self._summary_writer_resource = resource
+    self._eager_context.summary_writer_resource = resource
 
   @property
   def device_name(self):