From e65b8855b1bfa11fab09790fe0532b57b2ed56f2 Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Fri, 9 Oct 2020 13:31:03 -0700 Subject: [PATCH] Add outside compilation test for histogram summary that doesn't use tail extraction. PiperOrigin-RevId: 336354849 Change-Id: Iec2a8dbcca2d979317da151a56dddd4ccb73fc76 --- .../tpu/tpu_outside_compilation_test.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tensorflow/python/tpu/tpu_outside_compilation_test.py b/tensorflow/python/tpu/tpu_outside_compilation_test.py index 1385bbd2be7..4eb6429f3c8 100644 --- a/tensorflow/python/tpu/tpu_outside_compilation_test.py +++ b/tensorflow/python/tpu/tpu_outside_compilation_test.py @@ -24,6 +24,7 @@ import tempfile from absl.testing import parameterized import numpy as np +from tensorboard.plugins.histogram import summary_v2 as histogram_summary_v2 from tensorboard.plugins.scalar import summary_v2 as scalar_summary_v2 from tensorflow.core.util import event_pb2 from tensorflow.python.distribute import tpu_strategy as tpu_lib @@ -516,6 +517,35 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase, self.assertLen(events, 2) self.assertEqual(events[1].summary.value[0].tag, "x") + def testHistogramSummaryWithAutoOutsideCompilation(self): + strategy = get_tpu_strategy() + + def host_computation(x): + histogram_summary_v2.histogram("x", x, step=0) + return x * 2.0 + + @def_function.function + def step(): + + def computation(x): + x = x + 1.0 + y = host_computation(x) + return y + 1.0 + + return strategy.run(computation, args=(2.0,)) + + logdir = tempfile.mkdtemp() + summary_writer = summary.create_file_writer(logdir, flush_millis=10000) + with summary_writer.as_default(), summary.always_record_summaries(): + self.assertAllEqual( + strategy.experimental_local_results(step()), + constant_op.constant(7., shape=(strategy.num_replicas_in_sync))) + events = _events_from_logdir(self, logdir) + # There will be 2 entries: 1 summary file header entry, and 1 entry + # written by host. + self.assertLen(events, 2) + self.assertEqual(events[1].summary.value[0].tag, "x") + @parameterized.parameters((True), (False)) def testSummaryControlFlowIfWithAutoOutsideCompilation( self, take_true_branch):