Change TF 2.0 tf.summary.SummaryWriter to map 1:1 to underlying resource
This changes the resource model for the TF 2.0 tf.summary.SummaryWriter as returned by tf.summary.create_file_writer(). Previously, we preserved the 1.x tf.contrib.summary.create_file_writer() semantics, where underlying resources were shared by logdir to facilitate reuse of a single eventfile per directory. That avoids accidentally overwriting an existing eventfile (which is possible if two are created within the same second) and is better supported by TensorBoard.

Unfortunately, that sharing behavior resulted in the buggy and confusing resource deletion behavior described in #25707, and it is also out of alignment with the overall TF 2.0 philosophy that resource lifetimes should match the lifetimes of their encapsulating Python objects.

The new SummaryWriter wrapper object fixes this by generating a unique resource name for each instance. Inside @tf.functions, this requires lifting initialization out using init_scope(), which means in-graph tensors can no longer be passed to create_file_writer(), and the returned SummaryWriter must be stored in a non-local variable so that it outlives the tracing phase and is still alive when the function graph executes. A future change will add better error messaging for the latter circumstance.

To keep the SummaryWriter <-> resource mapping 1:1, this also changes the close() method so that in eager mode it permanently closes the SummaryWriter and prevents re-initialization. (A future change may do this for graph mode as well.) To mitigate the risk of eventfile name collisions, the new wrapper bakes the PID and a TF-generated UID into the filename suffix.

Finally, this cleans up the public API surface by 1) moving the "logdir=None means no-op" behavior into a dedicated create_noop_writer() factory function, and 2) hiding SummaryWriter.__init__() from the public API by making SummaryWriter an abstract base class; instances should be obtained via factory functions anyway.

Fixes #25707. A subsequent change will update the TF 2.0 Keras callback to call create_file_writer_v2(), which should address #25524 and #24632.

PiperOrigin-RevId: 235249883
Commit 826027dbd4 (parent 25642bfb04)
Changed paths: tensorflow/python, tensorflow/tools/api/golden/v2
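For context, a minimal sketch of the usage pattern this change establishes, written against the v2 tf.summary API as exported at this revision. The logdir path, tags, values, and training loop below are illustrative, not from the commit:

    import tensorflow as tf

    # Each create call now owns a distinct resource: two writers on the same
    # logdir produce two separate event files instead of sharing one.
    writer = tf.summary.create_file_writer('/tmp/logs')  # illustrative path

    @tf.function
    def train_step(step):
      # The writer is created outside the tf.function and captured here. Its
      # initialization is lifted out via init_scope(), and the Python object
      # must outlive tracing so the resource still exists when the traced
      # graph actually runs.
      with writer.as_default():
        tf.summary.write('loss', 0.5, step=step)

    for i in range(3):
      train_step(tf.constant(i, dtype=tf.int64))

    writer.close()  # implicit flush; in eager mode the writer is now
                    # permanently closed, so writer.init() would raise
                    # RuntimeError("SummaryWriter is already closed")

    # Code that just needs a writer-shaped context manager can use the new
    # no-op factory instead of create_file_writer(logdir=None):
    with tf.summary.create_noop_writer().as_default():
      tf.summary.write('ignored', 1.0, step=0)  # records nothing

The key point is that the writer object now owns its resource outright, so a reference must be kept alive for as long as summaries will be written.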
@@ -3303,6 +3303,7 @@ py_library(
         ":smart_cond",
         ":summary_op_util",
         ":summary_ops_gen",
+        ":tensor_util",
         ":training_util",
         ":util",
         "//tensorflow/core:protos_all_py",
@@ -1093,6 +1093,7 @@ cuda_py_test(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:constant_op",
         "//tensorflow/python:dtypes",
+        "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:lib",
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import os
-import time
 import unittest
 
 from tensorflow.core.framework import graph_pb2
@@ -32,6 +31,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import tensor_util
@@ -52,7 +52,7 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
   def testWrite(self):
     logdir = self.get_temp_dir()
     with context.eager_mode():
-      with summary_ops.create_file_writer(logdir).as_default():
+      with summary_ops.create_file_writer_v2(logdir).as_default():
         output = summary_ops.write('tag', 42, step=12)
       self.assertTrue(output.numpy())
     events = events_from_logdir(logdir)
@@ -64,11 +64,12 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
 
   def testWrite_fromFunction(self):
     logdir = self.get_temp_dir()
-    @def_function.function
-    def f():
-      with summary_ops.create_file_writer(logdir).as_default():
-        return summary_ops.write('tag', 42, step=12)
     with context.eager_mode():
+      writer = summary_ops.create_file_writer_v2(logdir)
+      @def_function.function
+      def f():
+        with writer.as_default():
+          return summary_ops.write('tag', 42, step=12)
       output = f()
       self.assertTrue(output.numpy())
     events = events_from_logdir(logdir)
@@ -83,7 +84,7 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
     metadata = summary_pb2.SummaryMetadata()
     metadata.plugin_data.plugin_name = 'foo'
     with context.eager_mode():
-      with summary_ops.create_file_writer(logdir).as_default():
+      with summary_ops.create_file_writer_v2(logdir).as_default():
         summary_ops.write('obj', 0, 0, metadata=metadata)
         summary_ops.write('bytes', 0, 0, metadata=metadata.SerializeToString())
         m = constant_op.constant(metadata.SerializeToString())
@@ -104,7 +105,7 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
   def testWrite_ndarray(self):
     logdir = self.get_temp_dir()
     with context.eager_mode():
-      with summary_ops.create_file_writer(logdir).as_default():
+      with summary_ops.create_file_writer_v2(logdir).as_default():
         summary_ops.write('tag', [[1, 2], [3, 4]], step=12)
     events = events_from_logdir(logdir)
     value = events[1].summary.value[0]
@@ -114,7 +115,7 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
     logdir = self.get_temp_dir()
     with context.eager_mode():
       t = constant_op.constant([[1, 2], [3, 4]])
-      with summary_ops.create_file_writer(logdir).as_default():
+      with summary_ops.create_file_writer_v2(logdir).as_default():
         summary_ops.write('tag', t, step=12)
       expected = t.numpy()
     events = events_from_logdir(logdir)
@@ -123,11 +124,12 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
 
   def testWrite_tensor_fromFunction(self):
     logdir = self.get_temp_dir()
-    @def_function.function
-    def f(t):
-      with summary_ops.create_file_writer(logdir).as_default():
-        summary_ops.write('tag', t, step=12)
     with context.eager_mode():
+      writer = summary_ops.create_file_writer_v2(logdir)
+      @def_function.function
+      def f(t):
+        with writer.as_default():
+          summary_ops.write('tag', t, step=12)
       t = constant_op.constant([[1, 2], [3, 4]])
       f(t)
       expected = t.numpy()
@@ -138,7 +140,7 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
   def testWrite_stringTensor(self):
     logdir = self.get_temp_dir()
     with context.eager_mode():
-      with summary_ops.create_file_writer(logdir).as_default():
+      with summary_ops.create_file_writer_v2(logdir).as_default():
         summary_ops.write('tag', [b'foo', b'bar'], step=12)
     events = events_from_logdir(logdir)
     value = events[1].summary.value[0]
@@ -168,7 +170,7 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
   def testWrite_recordIf_constant(self):
     logdir = self.get_temp_dir()
     with context.eager_mode():
-      with summary_ops.create_file_writer(logdir).as_default():
+      with summary_ops.create_file_writer_v2(logdir).as_default():
         self.assertTrue(summary_ops.write('default', 1, step=0))
         with summary_ops.record_if(True):
           self.assertTrue(summary_ops.write('set_on', 1, step=0))
@@ -181,16 +183,17 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
 
   def testWrite_recordIf_constant_fromFunction(self):
     logdir = self.get_temp_dir()
-    @def_function.function
-    def f():
-      with summary_ops.create_file_writer(logdir).as_default():
-        # Use assertAllEqual instead of assertTrue since it works in a defun.
-        self.assertAllEqual(summary_ops.write('default', 1, step=0), True)
-        with summary_ops.record_if(True):
-          self.assertAllEqual(summary_ops.write('set_on', 1, step=0), True)
-        with summary_ops.record_if(False):
-          self.assertAllEqual(summary_ops.write('set_off', 1, step=0), False)
     with context.eager_mode():
+      writer = summary_ops.create_file_writer_v2(logdir)
+      @def_function.function
+      def f():
+        with writer.as_default():
+          # Use assertAllEqual instead of assertTrue since it works in a defun.
+          self.assertAllEqual(summary_ops.write('default', 1, step=0), True)
+          with summary_ops.record_if(True):
+            self.assertAllEqual(summary_ops.write('set_on', 1, step=0), True)
+          with summary_ops.record_if(False):
+            self.assertAllEqual(summary_ops.write('set_off', 1, step=0), False)
       f()
     events = events_from_logdir(logdir)
     self.assertEqual(3, len(events))
@@ -204,7 +207,7 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
       def record_fn():
         step.assign_add(1)
         return int(step % 2) == 0
-      with summary_ops.create_file_writer(logdir).as_default():
+      with summary_ops.create_file_writer_v2(logdir).as_default():
         with summary_ops.record_if(record_fn):
           self.assertTrue(summary_ops.write('tag', 1, step=step))
           self.assertFalse(summary_ops.write('tag', 1, step=step))
@@ -220,6 +223,7 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
   def testWrite_recordIf_callable_fromFunction(self):
     logdir = self.get_temp_dir()
     with context.eager_mode():
+      writer = summary_ops.create_file_writer_v2(logdir)
       step = variables.Variable(-1, dtype=dtypes.int64)
       @def_function.function
       def record_fn():
@@ -227,7 +231,7 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
         return math_ops.equal(step % 2, 0)
       @def_function.function
       def f():
-        with summary_ops.create_file_writer(logdir).as_default():
+        with writer.as_default():
           with summary_ops.record_if(record_fn):
             return [
                 summary_ops.write('tag', 1, step=step),
@@ -243,13 +247,14 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
 
   def testWrite_recordIf_tensorInput_fromFunction(self):
     logdir = self.get_temp_dir()
-    @def_function.function(input_signature=[
-        tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)])
-    def f(step):
-      with summary_ops.create_file_writer(logdir).as_default():
-        with summary_ops.record_if(math_ops.equal(step % 2, 0)):
-          return summary_ops.write('tag', 1, step=step)
     with context.eager_mode():
+      writer = summary_ops.create_file_writer_v2(logdir)
+      @def_function.function(input_signature=[
+          tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)])
+      def f(step):
+        with writer.as_default():
+          with summary_ops.record_if(math_ops.equal(step % 2, 0)):
+            return summary_ops.write('tag', 1, step=step)
       self.assertTrue(f(0))
      self.assertFalse(f(1))
      self.assertTrue(f(2))
@@ -311,77 +316,152 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
 
 class SummaryWriterTest(test_util.TensorFlowTestCase):
 
-  def testWriterInitAndClose(self):
+  def testCreate_withInitAndClose(self):
     logdir = self.get_temp_dir()
     with context.eager_mode():
-      writer = summary_ops.create_file_writer(
+      writer = summary_ops.create_file_writer_v2(
           logdir, max_queue=1000, flush_millis=1000000)
-      files = gfile.Glob(os.path.join(logdir, '*'))
-      self.assertEqual(1, len(files))
-      file1 = files[0]
-      self.assertEqual(1, len(events_from_file(file1)))  # file_version Event
+      get_total = lambda: len(events_from_logdir(logdir))
+      self.assertEqual(1, get_total())  # file_version Event
       # Calling init() again while writer is open has no effect
       writer.init()
-      self.assertEqual(1, len(events_from_file(file1)))
+      self.assertEqual(1, get_total())
       with writer.as_default():
         summary_ops.write('tag', 1, step=0)
-        self.assertEqual(1, len(events_from_file(file1)))
+        self.assertEqual(1, get_total())
       # Calling .close() should do an implicit flush
       writer.close()
-      self.assertEqual(2, len(events_from_file(file1)))
-      # Calling init() on a closed writer should start a new file
-      time.sleep(1.1)  # Ensure filename has a different timestamp
-      writer.init()
-      files = gfile.Glob(os.path.join(logdir, '*'))
-      self.assertEqual(2, len(files))
-      files.remove(file1)
-      file2 = files[0]
-      self.assertEqual(1, len(events_from_file(file2)))  # file_version
-      self.assertEqual(2, len(events_from_file(file1)))  # should be unchanged
+      self.assertEqual(2, get_total())
 
-  def testSharedName(self):
-    logdir = self.get_temp_dir()
-    with context.eager_mode():
-      # Create with default shared name (should match logdir)
-      writer1 = summary_ops.create_file_writer(logdir)
-      with writer1.as_default():
-        summary_ops.write('tag', 1, step=1)
-      summary_ops.flush()
-      # Create with explicit logdir shared name (should be same resource/file)
-      shared_name = 'logdir:' + logdir
-      writer2 = summary_ops.create_file_writer(logdir, name=shared_name)
-      with writer2.as_default():
-        summary_ops.write('tag', 1, step=2)
-      summary_ops.flush()
-      # Create with different shared name (should be separate resource/file)
-      time.sleep(1.1)  # Ensure filename has a different timestamp
-      writer3 = summary_ops.create_file_writer(logdir, name='other')
-      with writer3.as_default():
-        summary_ops.write('tag', 1, step=3)
-      summary_ops.flush()
-
-    event_files = iter(sorted(gfile.Glob(os.path.join(logdir, '*'))))
-
-    # First file has tags "one" and "two"
-    events = iter(events_from_file(next(event_files)))
-    self.assertEqual('brain.Event:2', next(events).file_version)
-    self.assertEqual(1, next(events).step)
-    self.assertEqual(2, next(events).step)
-    self.assertRaises(StopIteration, lambda: next(events))
-
-    # Second file has tag "three"
-    events = iter(events_from_file(next(event_files)))
-    self.assertEqual('brain.Event:2', next(events).file_version)
-    self.assertEqual(3, next(events).step)
-    self.assertRaises(StopIteration, lambda: next(events))
-
-    # No more files
-    self.assertRaises(StopIteration, lambda: next(event_files))
+  def testCreate_fromFunction(self):
+    logdir = self.get_temp_dir()
+    @def_function.function
+    def f():
+      # Returned SummaryWriter must be stored in a non-local variable so it
+      # lives throughout the function execution.
+      if not hasattr(f, 'writer'):
+        f.writer = summary_ops.create_file_writer_v2(logdir)
+    with context.eager_mode():
+      f()
+      event_files = gfile.Glob(os.path.join(logdir, '*'))
+      self.assertEqual(1, len(event_files))
+
+  def testCreate_graphTensorArgument_raisesError(self):
+    logdir = self.get_temp_dir()
+    with context.graph_mode():
+      logdir_tensor = constant_op.constant(logdir)
+    with context.eager_mode():
+      with self.assertRaisesRegex(
+          ValueError, 'Invalid graph Tensor argument.*logdir'):
+        summary_ops.create_file_writer_v2(logdir_tensor)
+    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))
+
+  def testCreate_fromFunction_graphTensorArgument_raisesError(self):
+    logdir = self.get_temp_dir()
+    @def_function.function
+    def f():
+      summary_ops.create_file_writer_v2(constant_op.constant(logdir))
+    with context.eager_mode():
+      with self.assertRaisesRegex(
+          ValueError, 'Invalid graph Tensor argument.*logdir'):
+        f()
+    self.assertEmpty(gfile.Glob(os.path.join(logdir, '*')))
+
+  def testCreate_fromFunction_unpersistedResource_raisesError(self):
+    logdir = self.get_temp_dir()
+    @def_function.function
+    def f():
+      with summary_ops.create_file_writer_v2(logdir).as_default():
+        pass  # Calling .as_default() is enough to indicate use.
+    with context.eager_mode():
+      # TODO(nickfelt): change this to a better error
+      with self.assertRaisesRegex(
+          errors.NotFoundError, 'Resource.*does not exist'):
+        f()
+    # Even though we didn't use it, an event file will have been created.
+    self.assertEqual(1, len(gfile.Glob(os.path.join(logdir, '*'))))
+
+  def testNoSharing(self):
+    # Two writers with the same logdir should not share state.
+    logdir = self.get_temp_dir()
+    with context.eager_mode():
+      writer1 = summary_ops.create_file_writer_v2(logdir)
+      with writer1.as_default():
+        summary_ops.write('tag', 1, step=1)
+      summary_ops.flush()
+      event_files = gfile.Glob(os.path.join(logdir, '*'))
+      self.assertEqual(1, len(event_files))
+      file1 = event_files[0]
+
+      writer2 = summary_ops.create_file_writer_v2(logdir)
+      with writer2.as_default():
+        summary_ops.write('tag', 1, step=2)
+      summary_ops.flush()
+      event_files = gfile.Glob(os.path.join(logdir, '*'))
+      self.assertEqual(2, len(event_files))
+      event_files.remove(file1)
+      file2 = event_files[0]
+
+      # Extra writes to ensure interleaved usage works.
+      with writer1.as_default():
+        summary_ops.write('tag', 1, step=1)
+      with writer2.as_default():
+        summary_ops.write('tag', 1, step=2)
+
+    events = iter(events_from_file(file1))
+    self.assertEqual('brain.Event:2', next(events).file_version)
+    self.assertEqual(1, next(events).step)
+    self.assertEqual(1, next(events).step)
+    self.assertRaises(StopIteration, lambda: next(events))
+    events = iter(events_from_file(file2))
+    self.assertEqual('brain.Event:2', next(events).file_version)
+    self.assertEqual(2, next(events).step)
+    self.assertEqual(2, next(events).step)
+    self.assertRaises(StopIteration, lambda: next(events))
+
+  def testNoSharing_fromFunction(self):
+    logdir = self.get_temp_dir()
+    @def_function.function
+    def f1():
+      if not hasattr(f1, 'writer'):
+        f1.writer = summary_ops.create_file_writer_v2(logdir)
+      with f1.writer.as_default():
+        summary_ops.write('tag', 1, step=1)
+    @def_function.function
+    def f2():
+      if not hasattr(f2, 'writer'):
+        f2.writer = summary_ops.create_file_writer_v2(logdir)
+      with f2.writer.as_default():
+        summary_ops.write('tag', 1, step=2)
+    with context.eager_mode():
+      f1()
+      event_files = gfile.Glob(os.path.join(logdir, '*'))
+      self.assertEqual(1, len(event_files))
+      file1 = event_files[0]
+
+      f2()
+      event_files = gfile.Glob(os.path.join(logdir, '*'))
+      self.assertEqual(2, len(event_files))
+      event_files.remove(file1)
+      file2 = event_files[0]
+
+      # Extra writes to ensure interleaved usage works.
+      f1()
+      f2()
+
+    events = iter(events_from_file(file1))
+    self.assertEqual('brain.Event:2', next(events).file_version)
+    self.assertEqual(1, next(events).step)
+    self.assertEqual(1, next(events).step)
+    self.assertRaises(StopIteration, lambda: next(events))
+    events = iter(events_from_file(file2))
+    self.assertEqual('brain.Event:2', next(events).file_version)
+    self.assertEqual(2, next(events).step)
+    self.assertEqual(2, next(events).step)
+    self.assertRaises(StopIteration, lambda: next(events))
 
   def testMaxQueue(self):
     logdir = self.get_temp_dir()
     with context.eager_mode():
-      with summary_ops.create_file_writer(
+      with summary_ops.create_file_writer_v2(
           logdir, max_queue=1, flush_millis=999999).as_default():
         get_total = lambda: len(events_from_logdir(logdir))
         # Note: First tf.Event is always file_version.
@@ -396,7 +476,7 @@ class SummaryWriterTest(test_util.TensorFlowTestCase):
     logdir = self.get_temp_dir()
     get_total = lambda: len(events_from_logdir(logdir))
     with context.eager_mode():
-      writer = summary_ops.create_file_writer(
+      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=1000, flush_millis=1000000)
       self.assertEqual(1, get_total())  # file_version Event
       with writer.as_default():
@@ -412,7 +492,7 @@ class SummaryWriterTest(test_util.TensorFlowTestCase):
   def testFlushFunction(self):
     logdir = self.get_temp_dir()
     with context.eager_mode():
-      writer = summary_ops.create_file_writer(
+      writer = summary_ops.create_file_writer_v2(
          logdir, max_queue=999999, flush_millis=999999)
       with writer.as_default(), summary_ops.always_record_summaries():
         get_total = lambda: len(events_from_logdir(logdir))
@@ -436,9 +516,24 @@ class SummaryWriterTest(test_util.TensorFlowTestCase):
   @test_util.assert_no_new_pyobjects_executing_eagerly
   def testEagerMemory(self):
     logdir = self.get_temp_dir()
-    with summary_ops.create_file_writer(logdir).as_default():
+    with summary_ops.create_file_writer_v2(logdir).as_default():
       summary_ops.write('tag', 1, step=0)
 
+  def testClose_preventsLaterUse(self):
+    logdir = self.get_temp_dir()
+    with context.eager_mode():
+      writer = summary_ops.create_file_writer_v2(logdir)
+      writer.close()
+      writer.close()  # redundant close() is a no-op
+      writer.flush()  # redundant flush() is a no-op
+      with self.assertRaisesRegex(RuntimeError, 'already closed'):
+        writer.init()
+      with self.assertRaisesRegex(RuntimeError, 'already closed'):
+        with writer.as_default():
+          self.fail('should not get here')
+      with self.assertRaisesRegex(RuntimeError, 'already closed'):
+        writer.set_as_default()
+
   def testClose_closesOpenFile(self):
     try:
       import psutil  # pylint: disable=g-import-not-at-top
@@ -448,7 +543,7 @@ class SummaryWriterTest(test_util.TensorFlowTestCase):
     get_open_filenames = lambda: set(info[0] for info in proc.open_files())
     logdir = self.get_temp_dir()
     with context.eager_mode():
-      writer = summary_ops.create_file_writer(logdir)
+      writer = summary_ops.create_file_writer_v2(logdir)
       files = gfile.Glob(os.path.join(logdir, '*'))
       self.assertEqual(1, len(files))
       eventfile = files[0]
@@ -465,7 +560,7 @@ class SummaryWriterTest(test_util.TensorFlowTestCase):
     get_open_filenames = lambda: set(info[0] for info in proc.open_files())
     logdir = self.get_temp_dir()
     with context.eager_mode():
-      writer = summary_ops.create_file_writer(logdir)
+      writer = summary_ops.create_file_writer_v2(logdir)
       files = gfile.Glob(os.path.join(logdir, '*'))
       self.assertEqual(1, len(files))
       eventfile = files[0]
@@ -19,6 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import abc
+import functools
 import getpass
 import os
 import re
@@ -34,6 +36,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_summary_ops
@@ -124,65 +127,120 @@ def never_record_summaries():
 
 
 @tf_export("summary.SummaryWriter", v1=[])
+@six.add_metaclass(abc.ABCMeta)
 class SummaryWriter(object):
-  """Encapsulates a stateful summary writer resource.
+  """Interface representing a stateful summary writer object."""
 
-  See also:
-  - `tf.summary.create_file_writer`
-  - `tf.summary.create_db_writer`
-  """
+  @abc.abstractmethod
+  def set_as_default(self):
+    """Enables this summary writer for the current thread."""
+    raise NotImplementedError()
 
-  def __init__(self, resource, init_op_fn):
-    self._resource = resource
-    # TODO(nickfelt): cache constructed ops in graph mode
-    self._init_op_fn = init_op_fn
-    if context.executing_eagerly() and self._resource is not None:
-      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
-          handle=self._resource, handle_device="cpu:0")
+  @abc.abstractmethod
+  @tf_contextlib.contextmanager
+  def as_default(self):
+    """Returns a context manager that enables summary writing."""
+    raise NotImplementedError()
 
-  def set_as_default(self):
-    """Enables this summary writer for the current thread."""
-    context.context().summary_writer_resource = self._resource
+  def init(self):
+    """Initializes the summary writer."""
+    raise NotImplementedError()
 
-  @tf_contextlib.contextmanager
-  def as_default(self):
-    """Enables summary writing within a `with` block."""
-    if self._resource is None:
-      yield self
-    else:
-      old = context.context().summary_writer_resource
-      try:
-        context.context().summary_writer_resource = self._resource
-        yield self
-        # Flushes the summary writer in eager mode or in graph functions, but
-        # not in legacy graph mode (you're on your own there).
-        with ops.device("cpu:0"):
-          gen_summary_ops.flush_summary_writer(self._resource)
-      finally:
-        context.context().summary_writer_resource = old
+  def flush(self):
+    """Flushes any buffered data."""
+    raise NotImplementedError()
 
-  def init(self):
-    """Operation to initialize the summary writer resource."""
-    if self._resource is not None:
-      return self._init_op_fn()
+  def close(self):
+    """Flushes and closes the summary writer."""
+    raise NotImplementedError()
 
-  def _flush(self):
-    return _flush_fn(writer=self)
 
-  def flush(self):
-    """Operation to force the summary writer to flush any buffered data."""
-    if self._resource is not None:
-      return self._flush()
+class ResourceSummaryWriter(SummaryWriter):
+  """Implementation of SummaryWriter using a SummaryWriterInterface resource."""
 
-  def _close(self):
-    with ops.control_dependencies([self.flush()]):
-      with ops.device("cpu:0"):
-        return gen_summary_ops.close_summary_writer(self._resource)
+  def __init__(self, shared_name, init_op_fn, name=None, v2=False):
+    self._resource = gen_summary_ops.summary_writer(
+        shared_name=shared_name, name=name)
+    # TODO(nickfelt): cache other constructed ops in graph mode
+    self._init_op_fn = init_op_fn
+    self._init_op = init_op_fn(self._resource)
+    self._v2 = v2
+    self._closed = False
+    if context.executing_eagerly():
+      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+          handle=self._resource, handle_device="cpu:0")
+    else:
+      global _SUMMARY_WRITER_INIT_OP
+      key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
+      _SUMMARY_WRITER_INIT_OP.setdefault(key, []).append(self._init_op)
 
-  def close(self):
-    """Operation to flush and close the summary writer resource."""
-    if self._resource is not None:
-      return self._close()
+  def set_as_default(self):
+    """Enables this summary writer for the current thread."""
+    if self._v2 and context.executing_eagerly() and self._closed:
+      raise RuntimeError("SummaryWriter is already closed")
+    context.context().summary_writer_resource = self._resource
+
+  @tf_contextlib.contextmanager
+  def as_default(self):
+    """Returns a context manager that enables summary writing."""
+    if self._v2 and context.executing_eagerly() and self._closed:
+      raise RuntimeError("SummaryWriter is already closed")
+    old = context.context().summary_writer_resource
+    try:
+      context.context().summary_writer_resource = self._resource
+      yield self
+      # Flushes the summary writer in eager mode or in graph functions, but
+      # not in legacy graph mode (you're on your own there).
+      self.flush()
+    finally:
+      context.context().summary_writer_resource = old
+
+  def init(self):
+    """Initializes the summary writer."""
+    if self._v2:
+      if context.executing_eagerly() and self._closed:
+        raise RuntimeError("SummaryWriter is already closed")
+      return self._init_op
+    # Legacy behavior allows re-initializing the resource.
+    return self._init_op_fn(self._resource)
+
+  def flush(self):
+    """Flushes any buffered data."""
+    if self._v2 and context.executing_eagerly() and self._closed:
+      return
+    return _flush_fn(writer=self)
+
+  def close(self):
+    """Flushes and closes the summary writer."""
+    if self._v2 and context.executing_eagerly() and self._closed:
+      return
+    try:
+      with ops.control_dependencies([self.flush()]):
+        with ops.device("cpu:0"):
+          return gen_summary_ops.close_summary_writer(self._resource)
+    finally:
+      if self._v2 and context.executing_eagerly():
+        self._closed = True
+
+
+class NoopSummaryWriter(SummaryWriter):
+  """A summary writer that does nothing, for create_noop_writer()."""
+
+  def set_as_default(self):
+    pass
+
+  @tf_contextlib.contextmanager
+  def as_default(self):
+    yield
+
+  def init(self):
+    pass
+
+  def flush(self):
+    pass
+
+  def close(self):
+    pass
 
 
 @tf_export(v1=["summary.initialize"])
@@ -228,6 +286,66 @@ def initialize(
 
 
+@tf_export("summary.create_file_writer", v1=[])
+def create_file_writer_v2(logdir,
+                          max_queue=None,
+                          flush_millis=None,
+                          filename_suffix=None,
+                          name=None):
+  """Creates a summary file writer for the given log directory.
+
+  Args:
+    logdir: a string specifying the directory in which to write an event file.
+    max_queue: the largest number of summaries to keep in a queue; will
+      flush once the queue gets bigger than this. Defaults to 10.
+    flush_millis: the largest interval between flushes. Defaults to 120,000.
+    filename_suffix: optional suffix for the event file name. Defaults to `.v2`.
+    name: a name for the op that creates the writer.
+
+  Returns:
+    A SummaryWriter object.
+  """
+  if logdir is None:
+    raise ValueError("logdir cannot be None")
+  inside_function = ops.inside_function()
+  with ops.name_scope(name, "create_file_writer") as scope, ops.device("cpu:0"):
+    # Run init inside an init_scope() to hoist it out of tf.functions.
+    with ops.init_scope():
+      if context.executing_eagerly():
+        _check_create_file_writer_args(
+            inside_function,
+            logdir=logdir,
+            max_queue=max_queue,
+            flush_millis=flush_millis,
+            filename_suffix=filename_suffix)
+      logdir = ops.convert_to_tensor(logdir, dtype=dtypes.string)
+      if max_queue is None:
+        max_queue = constant_op.constant(10)
+      if flush_millis is None:
+        flush_millis = constant_op.constant(2 * 60 * 1000)
+      if filename_suffix is None:
+        filename_suffix = constant_op.constant(".v2")
+      # Prepend the PID and a process-local UID to the filename suffix to avoid
+      # filename collisions within the machine (the filename already contains
+      # the hostname to avoid cross-machine collisions).
+      unique_prefix = constant_op.constant(".%s.%s" % (os.getpid(), ops.uid()))
+      filename_suffix = unique_prefix + filename_suffix
+      # Use a unique shared_name to prevent resource sharing.
+      if context.executing_eagerly():
+        shared_name = context.shared_name()
+      else:
+        shared_name = ops._name_from_scope_name(scope)  # pylint: disable=protected-access
+      return ResourceSummaryWriter(
+          shared_name=shared_name,
+          init_op_fn=functools.partial(
+              gen_summary_ops.create_summary_file_writer,
+              logdir=logdir,
+              max_queue=max_queue,
+              flush_millis=flush_millis,
+              filename_suffix=filename_suffix),
+          name=name,
+          v2=True)
+
+
 def create_file_writer(logdir,
                        max_queue=None,
                        flush_millis=None,
@@ -254,7 +372,7 @@ def create_file_writer(logdir,
     summary writer.
   """
   if logdir is None:
-    return SummaryWriter(None, None)
+    return NoopSummaryWriter()
   logdir = str(logdir)
   with ops.device("cpu:0"):
     if max_queue is None:
@@ -265,13 +383,14 @@ def create_file_writer(logdir,
      filename_suffix = constant_op.constant(".v2")
    if name is None:
      name = "logdir:" + logdir
-    return _make_summary_writer(
-        name,
-        gen_summary_ops.create_summary_file_writer,
-        logdir=logdir,
-        max_queue=max_queue,
-        flush_millis=flush_millis,
-        filename_suffix=filename_suffix)
+    return ResourceSummaryWriter(
+        shared_name=name,
+        init_op_fn=functools.partial(
+            gen_summary_ops.create_summary_file_writer,
+            logdir=logdir,
+            max_queue=max_queue,
+            flush_millis=flush_millis,
+            filename_suffix=filename_suffix))
 
 
 def create_db_writer(db_uri,
@@ -316,26 +435,23 @@ def create_db_writer(db_uri,
       "experiment_name", _EXPERIMENT_NAME_PATTERNS, experiment_name)
   run_name = _cleanse_string("run_name", _RUN_NAME_PATTERNS, run_name)
   user_name = _cleanse_string("user_name", _USER_NAME_PATTERNS, user_name)
-  return _make_summary_writer(
-      name,
-      gen_summary_ops.create_summary_db_writer,
-      db_uri=db_uri,
-      experiment_name=experiment_name,
-      run_name=run_name,
-      user_name=user_name)
-
-
-def _make_summary_writer(name, factory, **kwargs):
-  resource = gen_summary_ops.summary_writer(shared_name=name)
-  init_op_fn = lambda: factory(resource, **kwargs)
-  init_op = init_op_fn()
-  if not context.executing_eagerly():
-    # TODO(apassos): Consider doing this instead.
-    # ops.get_default_session().run(init_op)
-    global _SUMMARY_WRITER_INIT_OP
-    key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
-    _SUMMARY_WRITER_INIT_OP.setdefault(key, []).append(init_op)
-  return SummaryWriter(resource, init_op_fn)
+  return ResourceSummaryWriter(
+      shared_name=name,
+      init_op_fn=functools.partial(
+          gen_summary_ops.create_summary_db_writer,
+          db_uri=db_uri,
+          experiment_name=experiment_name,
+          run_name=run_name,
+          user_name=user_name))
+
+
+@tf_export("summary.create_noop_writer", v1=[])
+def create_noop_writer():
+  """Returns a summary writer that does nothing.
+
+  This is useful as a placeholder in code that expects a context manager.
+  """
+  return NoopSummaryWriter()
 
 
 def _cleanse_string(name, pattern, value):
@@ -436,7 +552,7 @@ def write(tag, tensor, step, metadata=None, name=None):
     tag: string tag used to identify the summary (e.g. in TensorBoard), usually
       generated with `tf.summary.summary_scope`
     tensor: the Tensor holding the summary data to write
-    step: `int64`-castable monotic step value for this summary
+    step: `int64`-castable monotonic step value for this summary
     metadata: Optional SummaryMetadata, as a proto or serialized bytes
     name: Optional string name for this op.
 
@@ -734,6 +850,30 @@ def _choose_step(step):
   return step
 
 
+def _check_create_file_writer_args(inside_function, **kwargs):
+  """Helper to check the validity of arguments to a create_file_writer() call.
+
+  Args:
+    inside_function: whether the create_file_writer() call is in a tf.function
+    **kwargs: the arguments to check, as kwargs to give them names.
+
+  Raises:
+    ValueError: if the arguments are graph tensors.
+  """
+  for arg_name, arg in kwargs.items():
+    if not isinstance(arg, ops.EagerTensor) and tensor_util.is_tensor(arg):
+      if inside_function:
+        raise ValueError(
+            "Invalid graph Tensor argument \"%s=%s\" to create_file_writer() "
+            "inside an @tf.function. The create call will be lifted into the "
+            "outer eager execution context, so it cannot consume graph tensors "
+            "defined inside the function body." % (arg_name, arg))
+      else:
+        raise ValueError(
+            "Invalid graph Tensor argument \"%s=%s\" to eagerly executed "
+            "create_file_writer()." % (arg_name, arg))
+
+
 def run_metadata(name, data, step):
   """Writes entire RunMetadata summary.
 
@@ -4,7 +4,6 @@ tf_class {
   is_instance: "<type \'object\'>"
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'resource\', \'init_op_fn\'], varargs=None, keywords=None, defaults=None"
   }
   member_method {
     name: "as_default"
@@ -12,6 +12,10 @@ tf_module {
     name: "create_file_writer"
     argspec: "args=[\'logdir\', \'max_queue\', \'flush_millis\', \'filename_suffix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
   }
+  member_method {
+    name: "create_noop_writer"
+    argspec: "args=[], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "flush"
     argspec: "args=[\'writer\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "