Add unit test for outside compilation with summary inside control flow.
This covers automatic outside compilation with If control flow and summary ops. PiperOrigin-RevId: 326345698 Change-Id: Id308670fd002036ffdc51e64d5995d057d157493
This commit is contained in:
parent
3ea5fc7f3f
commit
730b802ee8
@ -470,7 +470,8 @@ class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
|
|||||||
constant_op.constant(2916., shape=(strategy.num_replicas_in_sync)))
|
constant_op.constant(2916., shape=(strategy.num_replicas_in_sync)))
|
||||||
|
|
||||||
|
|
||||||
class OutsideCompilationOnUnsupportedOpTest(test.TestCase):
|
class OutsideCompilationOnUnsupportedOpTest(test.TestCase,
|
||||||
|
parameterized.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super(OutsideCompilationOnUnsupportedOpTest, self).setUp()
|
super(OutsideCompilationOnUnsupportedOpTest, self).setUp()
|
||||||
@ -536,6 +537,45 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase):
|
|||||||
self.assertEqual(events[1].summary.value[0].tag, "x")
|
self.assertEqual(events[1].summary.value[0].tag, "x")
|
||||||
self.assertEqual(events[1].summary.value[0].simple_value, 3.0)
|
self.assertEqual(events[1].summary.value[0].simple_value, 3.0)
|
||||||
|
|
||||||
|
@parameterized.parameters((True), (False))
|
||||||
|
def testSummaryControlFlowIfWithAutoOutsideCompilation(
|
||||||
|
self, take_true_branch):
|
||||||
|
strategy = get_tpu_strategy()
|
||||||
|
|
||||||
|
@def_function.function
|
||||||
|
def step():
|
||||||
|
|
||||||
|
def computation(x):
|
||||||
|
x = x + 1.0
|
||||||
|
if x < 5:
|
||||||
|
summary.scalar("x", x, step=0)
|
||||||
|
x = x * 2.0
|
||||||
|
return x + 1.0
|
||||||
|
|
||||||
|
if take_true_branch:
|
||||||
|
return strategy.run(computation, args=(2.0,))
|
||||||
|
else:
|
||||||
|
return strategy.run(computation, args=(10.0,))
|
||||||
|
|
||||||
|
logdir = tempfile.mkdtemp()
|
||||||
|
summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
|
||||||
|
output_value = 12.
|
||||||
|
if take_true_branch:
|
||||||
|
output_value = 7.
|
||||||
|
with summary_writer.as_default(), summary.always_record_summaries():
|
||||||
|
self.assertAllEqual(
|
||||||
|
strategy.experimental_local_results(step()),
|
||||||
|
constant_op.constant(
|
||||||
|
output_value, shape=(strategy.num_replicas_in_sync)))
|
||||||
|
if take_true_branch:
|
||||||
|
events = _events_from_logdir(self, logdir)
|
||||||
|
# There will be 2 entries: 1 summary file header entry, and 1 entry
|
||||||
|
# written by host.
|
||||||
|
#
|
||||||
|
self.assertLen(events, 2)
|
||||||
|
self.assertEqual(events[1].summary.value[0].tag, "cond/x")
|
||||||
|
self.assertEqual(events[1].summary.value[0].simple_value, 3.0)
|
||||||
|
|
||||||
def testAutoOutsideCompilationWithFunctionalNodes(self):
|
def testAutoOutsideCompilationWithFunctionalNodes(self):
|
||||||
strategy = get_tpu_strategy()
|
strategy = get_tpu_strategy()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user