Switching contrib.summaries API to be context-manager-centric
PiperOrigin-RevId: 173129793
This commit is contained in:
parent
03b02ffc9e
commit
4ec6f2b07c
@ -27,6 +27,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.layers import utils
|
||||
from tensorflow.python.ops import summary_op_util
|
||||
from tensorflow.python.training import training_util
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
|
||||
# Name for a collection which is expected to have at most a single boolean
|
||||
# Tensor. If this tensor is True the summary ops will record summaries.
|
||||
@ -46,22 +47,50 @@ def should_record_summaries():
|
||||
|
||||
|
||||
# TODO(apassos) consider how to handle local step here.
|
||||
@tf_contextlib.contextmanager
|
||||
def record_summaries_every_n_global_steps(n):
|
||||
"""Sets the should_record_summaries Tensor to true if global_step % n == 0."""
|
||||
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
|
||||
old = collection_ref[:]
|
||||
collection_ref[:] = [training_util.get_global_step() % n == 0]
|
||||
yield
|
||||
collection_ref[:] = old
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def always_record_summaries():
|
||||
"""Sets the should_record_summaries Tensor to always true."""
|
||||
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
|
||||
old = collection_ref[:]
|
||||
collection_ref[:] = [True]
|
||||
yield
|
||||
collection_ref[:] = old
|
||||
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def never_record_summaries():
|
||||
"""Sets the should_record_summaries Tensor to always false."""
|
||||
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
|
||||
old = collection_ref[:]
|
||||
collection_ref[:] = [False]
|
||||
yield
|
||||
collection_ref[:] = old
|
||||
|
||||
|
||||
class SummaryWriter(object):
|
||||
|
||||
def __init__(self, resource):
|
||||
self._resource = resource
|
||||
|
||||
def set_as_default(self):
|
||||
context.context().summary_writer_resource = self._resource
|
||||
|
||||
@tf_contextlib.contextmanager
|
||||
def as_default(self):
|
||||
old = context.context().summary_writer_resource
|
||||
context.context().summary_writer_resource = self._resource
|
||||
yield
|
||||
context.context().summary_writer_resource = old
|
||||
|
||||
|
||||
def create_summary_file_writer(logdir,
|
||||
@ -77,9 +106,11 @@ def create_summary_file_writer(logdir,
|
||||
if filename_suffix is None:
|
||||
filename_suffix = constant_op.constant("")
|
||||
resource = gen_summary_ops.summary_writer(shared_name=name)
|
||||
# TODO(apassos) ensure the initialization op runs when in graph mode; consider
|
||||
# calling session.run here.
|
||||
gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue,
|
||||
flush_secs, filename_suffix)
|
||||
context.context().summary_writer_resource = resource
|
||||
return SummaryWriter(resource)
|
||||
|
||||
|
||||
def _nothing():
|
||||
|
@ -41,60 +41,65 @@ class TargetTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def testShouldRecordSummary(self):
|
||||
self.assertFalse(summary_ops.should_record_summaries())
|
||||
summary_ops.always_record_summaries()
|
||||
self.assertTrue(summary_ops.should_record_summaries())
|
||||
with summary_ops.always_record_summaries():
|
||||
self.assertTrue(summary_ops.should_record_summaries())
|
||||
|
||||
def testSummaryOps(self):
|
||||
training_util.get_or_create_global_step()
|
||||
logdir = tempfile.mkdtemp()
|
||||
summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0')
|
||||
summary_ops.always_record_summaries()
|
||||
summary_ops.generic('tensor', 1, '')
|
||||
summary_ops.scalar('scalar', 2.0)
|
||||
summary_ops.histogram('histogram', [1.0])
|
||||
summary_ops.image('image', [[[[1.0]]]])
|
||||
summary_ops.audio('audio', [[1.0]], 1.0, 1)
|
||||
# The working condition of the ops is tested in the C++ test so we just
|
||||
# test here that we're calling them correctly.
|
||||
self.assertTrue(gfile.Exists(logdir))
|
||||
with summary_ops.create_summary_file_writer(
|
||||
logdir, max_queue=0,
|
||||
name='t0').as_default(), summary_ops.always_record_summaries():
|
||||
summary_ops.generic('tensor', 1, '')
|
||||
summary_ops.scalar('scalar', 2.0)
|
||||
summary_ops.histogram('histogram', [1.0])
|
||||
summary_ops.image('image', [[[[1.0]]]])
|
||||
summary_ops.audio('audio', [[1.0]], 1.0, 1)
|
||||
# The working condition of the ops is tested in the C++ test so we just
|
||||
# test here that we're calling them correctly.
|
||||
self.assertTrue(gfile.Exists(logdir))
|
||||
|
||||
def testDefunSummarys(self):
|
||||
training_util.get_or_create_global_step()
|
||||
logdir = tempfile.mkdtemp()
|
||||
summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1')
|
||||
summary_ops.always_record_summaries()
|
||||
with summary_ops.create_summary_file_writer(
|
||||
logdir, max_queue=0,
|
||||
name='t1').as_default(), summary_ops.always_record_summaries():
|
||||
|
||||
@function.defun
|
||||
def write():
|
||||
summary_ops.scalar('scalar', 2.0)
|
||||
@function.defun
|
||||
def write():
|
||||
summary_ops.scalar('scalar', 2.0)
|
||||
|
||||
write()
|
||||
write()
|
||||
|
||||
self.assertTrue(gfile.Exists(logdir))
|
||||
files = gfile.ListDirectory(logdir)
|
||||
self.assertEqual(len(files), 1)
|
||||
records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
|
||||
self.assertEqual(len(records), 2)
|
||||
event = event_pb2.Event()
|
||||
event.ParseFromString(records[1])
|
||||
self.assertEqual(event.summary.value[0].simple_value, 2.0)
|
||||
self.assertTrue(gfile.Exists(logdir))
|
||||
files = gfile.ListDirectory(logdir)
|
||||
self.assertEqual(len(files), 1)
|
||||
records = list(
|
||||
tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
|
||||
self.assertEqual(len(records), 2)
|
||||
event = event_pb2.Event()
|
||||
event.ParseFromString(records[1])
|
||||
self.assertEqual(event.summary.value[0].simple_value, 2.0)
|
||||
|
||||
def testSummaryName(self):
|
||||
training_util.get_or_create_global_step()
|
||||
logdir = tempfile.mkdtemp()
|
||||
summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t2')
|
||||
summary_ops.always_record_summaries()
|
||||
with summary_ops.create_summary_file_writer(
|
||||
logdir, max_queue=0,
|
||||
name='t2').as_default(), summary_ops.always_record_summaries():
|
||||
|
||||
summary_ops.scalar('scalar', 2.0)
|
||||
summary_ops.scalar('scalar', 2.0)
|
||||
|
||||
self.assertTrue(gfile.Exists(logdir))
|
||||
files = gfile.ListDirectory(logdir)
|
||||
self.assertEqual(len(files), 1)
|
||||
records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
|
||||
self.assertEqual(len(records), 2)
|
||||
event = event_pb2.Event()
|
||||
event.ParseFromString(records[1])
|
||||
self.assertEqual(event.summary.value[0].tag, 'scalar')
|
||||
self.assertTrue(gfile.Exists(logdir))
|
||||
files = gfile.ListDirectory(logdir)
|
||||
self.assertEqual(len(files), 1)
|
||||
records = list(
|
||||
tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
|
||||
self.assertEqual(len(records), 2)
|
||||
event = event_pb2.Event()
|
||||
event.ParseFromString(records[1])
|
||||
self.assertEqual(event.summary.value[0].tag, 'scalar')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -58,6 +58,7 @@ class _EagerContext(threading.local):
|
||||
self.mode = _default_mode
|
||||
self.scope_name = ""
|
||||
self.recording_summaries = False
|
||||
self.summary_writer_resource = None
|
||||
self.scalar_cache = {}
|
||||
|
||||
|
||||
@ -86,7 +87,6 @@ class Context(object):
|
||||
self._eager_context = _EagerContext()
|
||||
self._context_handle = None
|
||||
self._context_devices = None
|
||||
self._summary_writer_resource = None
|
||||
self._post_execution_callbacks = []
|
||||
self._config = config
|
||||
self._seed = None
|
||||
@ -213,12 +213,12 @@ class Context(object):
|
||||
@property
|
||||
def summary_writer_resource(self):
|
||||
"""Returns summary writer resource."""
|
||||
return self._summary_writer_resource
|
||||
return self._eager_context.summary_writer_resource
|
||||
|
||||
@summary_writer_resource.setter
|
||||
def summary_writer_resource(self, resource):
|
||||
"""Sets summary writer resource."""
|
||||
self._summary_writer_resource = resource
|
||||
self._eager_context.summary_writer_resource = resource
|
||||
|
||||
@property
|
||||
def device_name(self):
|
||||
|
Loading…
Reference in New Issue
Block a user