Add simple AutomaticOutsideCompilation test with scalar summary.
PiperOrigin-RevId: 325285306 Change-Id: I7dfac98cc2dc592af346c7229a298b48df568845
This commit is contained in:
parent
c1244778c1
commit
4d59bcb41e
@ -19,10 +19,12 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.core.util import event_pb2
|
||||
from tensorflow.python.distribute import tpu_strategy as tpu_lib
|
||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
|
||||
from tensorflow.python.eager import def_function
|
||||
@ -30,6 +32,7 @@ from tensorflow.python.eager import remote
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import config
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.lib.io import tf_record
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gradients_impl
|
||||
@ -40,6 +43,7 @@ from tensorflow.python.ops import string_ops
|
||||
from tensorflow.python.ops import summary_ops_v2 as summary
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import flags
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.tpu import tpu
|
||||
from tensorflow.python.tpu import tpu_strategy_util
|
||||
|
||||
@ -70,6 +74,20 @@ def computation_with_string_ops(x):
|
||||
return string_ops.string_to_number(output)
|
||||
|
||||
|
||||
def _events_from_logdir(test_case, logdir):
|
||||
"""Reads summary events from log directory."""
|
||||
test_case.assertTrue(gfile.Exists(logdir))
|
||||
files = gfile.ListDirectory(logdir)
|
||||
test_case.assertLen(files, 1)
|
||||
records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0])))
|
||||
result = []
|
||||
for r in records:
|
||||
event = event_pb2.Event()
|
||||
event.ParseFromString(r)
|
||||
result.append(event)
|
||||
return result
|
||||
|
||||
|
||||
class TpuOutsideCompilationTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testResourceVariableAssignOnHost(self):
|
||||
@ -488,6 +506,36 @@ class OutsideCompilationOnUnsupportedOpTest(test.TestCase):
|
||||
strategy.experimental_local_results(train_step(0)),
|
||||
constant_op.constant(10, shape=(strategy.num_replicas_in_sync)))
|
||||
|
||||
def testSummaryWithAutoOutsideCompilation(self):
|
||||
strategy = get_tpu_strategy()
|
||||
|
||||
def host_computation(x):
|
||||
summary.scalar("x", x, step=0)
|
||||
return x * 2.0
|
||||
|
||||
@def_function.function
|
||||
def step():
|
||||
|
||||
def computation(x):
|
||||
x = x + 1.0
|
||||
y = host_computation(x)
|
||||
return y + 1.0
|
||||
|
||||
return strategy.run(computation, args=(2.0,))
|
||||
|
||||
logdir = tempfile.mkdtemp()
|
||||
summary_writer = summary.create_file_writer(logdir, flush_millis=10000)
|
||||
with summary_writer.as_default(), summary.always_record_summaries():
|
||||
self.assertAllEqual(
|
||||
strategy.experimental_local_results(step()),
|
||||
constant_op.constant(7., shape=(strategy.num_replicas_in_sync)))
|
||||
events = _events_from_logdir(self, logdir)
|
||||
# There will be 2 entries: 1 summary file header entry, and 1 entry
|
||||
# written by host.
|
||||
self.assertLen(events, 2)
|
||||
self.assertEqual(events[1].summary.value[0].tag, "x")
|
||||
self.assertEqual(events[1].summary.value[0].simple_value, 3.0)
|
||||
|
||||
def testAutoOutsideCompilationWithFunctionalNodes(self):
|
||||
strategy = get_tpu_strategy()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user