Disable summary step error when summaries are disabled.
Summary write functions do not complain about steps when summary writer is not configured, but they do complain when summary write is disabled by `tf.summary.record_if(False)`. This change makes the error detection more consistent, and makes the summary function not raise errors of missing step numbers when summary is disabled. PiperOrigin-RevId: 358122924 Change-Id: I2e1b17ddca8e1e6d0b99d0445cd9b5e14be158b0
This commit is contained in:
parent
426558e017
commit
5f775be482
@ -179,6 +179,14 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaisesRegex(ValueError, 'No step set'):
|
||||
summary_ops.write('tag', 42)
|
||||
|
||||
@test_util.also_run_as_tf_function
|
||||
def testWrite_noStep_okayIfNotRecordingSummaries(self):
|
||||
logdir = self.get_temp_dir()
|
||||
with summary_ops.create_file_writer(logdir).as_default():
|
||||
with summary_ops.record_if(False):
|
||||
# Use assertAllEqual instead of assertFalse since it works in a defun.
|
||||
self.assertAllEqual(False, summary_ops.write('tag', 42))
|
||||
|
||||
def testWrite_usingDefaultStep(self):
|
||||
logdir = self.get_temp_dir()
|
||||
try:
|
||||
|
@ -697,9 +697,6 @@ def write(tag, tensor, step=None, metadata=None, name=None):
|
||||
return constant_op.constant(False)
|
||||
if step is None:
|
||||
step = get_step()
|
||||
if step is None:
|
||||
raise ValueError("No step set via 'step' argument or "
|
||||
"tf.summary.experimental.set_step()")
|
||||
if metadata is None:
|
||||
serialized_metadata = b""
|
||||
elif hasattr(metadata, "SerializeToString"):
|
||||
@ -709,6 +706,10 @@ def write(tag, tensor, step=None, metadata=None, name=None):
|
||||
|
||||
def record():
|
||||
"""Record the actual summary and return True."""
|
||||
if step is None:
|
||||
raise ValueError("No step set via 'step' argument or "
|
||||
"tf.summary.experimental.set_step()")
|
||||
|
||||
# Note the identity to move the tensor to the CPU.
|
||||
with ops.device("cpu:0"):
|
||||
summary_tensor = tensor() if callable(tensor) else array_ops.identity(
|
||||
|
Loading…
Reference in New Issue
Block a user