diff --git a/tensorflow/python/kernel_tests/summary_ops_test.py b/tensorflow/python/kernel_tests/summary_ops_test.py index d5cd9d7bf43..12d2a03ea77 100644 --- a/tensorflow/python/kernel_tests/summary_ops_test.py +++ b/tensorflow/python/kernel_tests/summary_ops_test.py @@ -179,6 +179,14 @@ class SummaryOpsCoreTest(test_util.TensorFlowTestCase): with self.assertRaisesRegex(ValueError, 'No step set'): summary_ops.write('tag', 42) + @test_util.also_run_as_tf_function + def testWrite_noStep_okayIfNotRecordingSummaries(self): + logdir = self.get_temp_dir() + with summary_ops.create_file_writer(logdir).as_default(): + with summary_ops.record_if(False): + # Use assertAllEqual instead of assertFalse since it works in a defun. + self.assertAllEqual(False, summary_ops.write('tag', 42)) + def testWrite_usingDefaultStep(self): logdir = self.get_temp_dir() try: diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py index a334f95b9fe..f1630b6a248 100644 --- a/tensorflow/python/ops/summary_ops_v2.py +++ b/tensorflow/python/ops/summary_ops_v2.py @@ -697,9 +697,6 @@ def write(tag, tensor, step=None, metadata=None, name=None): return constant_op.constant(False) if step is None: step = get_step() - if step is None: - raise ValueError("No step set via 'step' argument or " - "tf.summary.experimental.set_step()") if metadata is None: serialized_metadata = b"" elif hasattr(metadata, "SerializeToString"): @@ -709,6 +706,10 @@ def write(tag, tensor, step=None, metadata=None, name=None): def record(): """Record the actual summary and return True.""" + if step is None: + raise ValueError("No step set via 'step' argument or " + "tf.summary.experimental.set_step()") + # Note the identity to move the tensor to the CPU. with ops.device("cpu:0"): summary_tensor = tensor() if callable(tensor) else array_ops.identity(