Add test with nested tf.function and outside compilation.

PiperOrigin-RevId: 359156369
Change-Id: I083557efbd5d1dc7642244d666361b12e21960f7
This commit is contained in:
Ken Franko 2021-02-23 16:03:11 -08:00 committed by TensorFlower Gardener
parent 2edf9157cf
commit fe588d3a3d

View File

@ -582,6 +582,36 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase,
self.assertLen(events, 2)
self.assertEqual(events[1].summary.value[0].tag, "x")
def testNestedFunctionScalarSummary(self):
strategy = get_tpu_strategy()
def host_computation(x):
scalar_summary_v2.scalar("x", x, step=0)
return x * 2.0
@def_function.function
def step():
@def_function.function
def computation(x):
x = x + 1.0
y = host_computation(x)
return y + 1.0
return strategy.run(computation, args=(2.0,))
logdir = tempfile.mkdtemp()
summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
with summary_writer.as_default(), summary.always_record_summaries():
self.assertAllEqual(
strategy.experimental_local_results(step()),
constant_op.constant(7., shape=(strategy.num_replicas_in_sync)))
events = _events_from_logdir(self, logdir)
# There will be 2 entries: 1 summary file header entry, and 1 entry
# written by host.
self.assertLen(events, 2)
self.assertEqual(events[1].summary.value[0].tag, "x")
def testHistogramSummaryWithAutoOutsideCompilation(self):
strategy = get_tpu_strategy()