From 730b802ee896dc410a1517b3ceaf7cc7183dabdb Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Wed, 12 Aug 2020 17:09:55 -0700 Subject: [PATCH] Add unit test for outside compilation with summary inside control flow. This covers automatic outside compilation with If control flow and summary ops. PiperOrigin-RevId: 326345698 Change-Id: Id308670fd002036ffdc51e64d5995d057d157493 --- .../tpu/tpu_outside_compilation_test.py | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py index 72e9f10d184..7e0278aa343 100644 --- a/tensorflow/python/tpu/tpu_outside_compilation_test.py +++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py @@ -470,7 +470,8 @@ class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase): constant_op.constant(2916., shape=(strategy.num_replicas_in_sync))) -class OutsideCompilationOnUnsupportedOpTest(test.TestCase): +class OutsideCompilationOnUnsupportedOpTest(test.TestCase, + parameterized.TestCase): def setUp(self): super(OutsideCompilationOnUnsupportedOpTest, self).setUp() @@ -536,6 +537,45 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase): self.assertEqual(events[1].summary.value[0].tag, "x") self.assertEqual(events[1].summary.value[0].simple_value, 3.0) + @parameterized.parameters((True), (False)) + def testSummaryControlFlowIfWithAutoOutsideCompilation( + self, take_true_branch): + strategy = get_tpu_strategy() + + @def_function.function + def step(): + + def computation(x): + x = x + 1.0 + if x < 5: + summary.scalar("x", x, step=0) + x = x * 2.0 + return x + 1.0 + + if take_true_branch: + return strategy.run(computation, args=(2.0,)) + else: + return strategy.run(computation, args=(10.0,)) + + logdir = tempfile.mkdtemp() + summary_writer = summary.create_file_writer(logdir, flush_millis=10000) + output_value = 12. + if take_true_branch: + output_value = 7. + with summary_writer.as_default(), summary.always_record_summaries(): + self.assertAllEqual( + strategy.experimental_local_results(step()), + constant_op.constant( + output_value, shape=(strategy.num_replicas_in_sync))) + if take_true_branch: + events = _events_from_logdir(self, logdir) + # There will be 2 entries: 1 summary file header entry, and 1 entry + # written by host. + # + self.assertLen(events, 2) + self.assertEqual(events[1].summary.value[0].tag, "cond/x") + self.assertEqual(events[1].summary.value[0].simple_value, 3.0) + def testAutoOutsideCompilationWithFunctionalNodes(self): strategy = get_tpu_strategy()