Add test for recovering from compile errors with TPUStrategy.

PiperOrigin-RevId: 298927832
Change-Id: I393e73a0869fa51f8c7c90bbfbdc4c1fa8ba521e
This commit is contained in:
Ken Franko 2020-03-04 13:36:06 -08:00 committed by TensorFlower Gardener
parent f7d4c7ffd5
commit f6793de3fe

View File

@ -18,9 +18,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import errors
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_strategy_util
@ -32,6 +36,17 @@ flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")
def get_tpu_strategy():
  """Builds a `TPUStrategy` for the TPU described by the command-line flags.

  Creates a cluster resolver from `FLAGS.tpu`, `FLAGS.zone` and
  `FLAGS.project`, connects the eager runtime to that cluster, initializes
  the TPU system, and wraps the resolver in a strategy.

  Returns:
    A `tpu_lib.TPUStrategy` constructed from the resolver.
  """
  cluster_resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)

  # Connect first, then initialize the TPU system, before building the
  # strategy on the same resolver.
  remote.connect_to_cluster(cluster_resolver)
  tpu_strategy_util.initialize_tpu_system(cluster_resolver)

  strategy = tpu_lib.TPUStrategy(cluster_resolver)
  return strategy
class TpuStrategyTest(test.TestCase):
def test_multiple_initialize_system(self):
@ -47,6 +62,33 @@ class TpuStrategyTest(test.TestCase):
tpu_strategy_util.initialize_tpu_system(resolver)
self.assertRegex(str(mock_log.call_args), "already been initialized")
def test_recover_from_compilation_failures(self):
  """Checks that a TPU compilation failure does not block later runs.

  First runs a replicated function that is expected to raise
  `errors.InvalidArgumentError` matching "TPU compilation failed". Then runs
  a second replicated function on the same strategy, which must complete
  without raising.
  """
  strategy = get_tpu_strategy()

  @def_function.function
  def compilation_failure_run():

    def computation():
      # This is the op the test expects to be rejected at TPU compile time
      # (see the assertion below).
      samples = random_ops.random_gamma([10], [0.5, 1.5])
      return samples

    return strategy.experimental_run_v2(computation)

  # `assertRaisesRegex` replaces the deprecated `assertRaisesRegexp` alias
  # (removed from unittest in Python 3.12) and matches the `assertRegex`
  # spelling already used in this test case.
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "TPU compilation failed"):
    compilation_failure_run()

  @def_function.function
  def good_run():

    def computation():
      samples = random_ops.random_normal([10])
      return samples

    return strategy.experimental_run_v2(computation)

  # Must not raise: the strategy should remain usable after the failed
  # compilation above.
  good_run()
# Run the test cases in this module when it is executed as a script.
if __name__ == "__main__":
  test.main()