Add test for recovering from compile errors w/ TPUStrategy.
PiperOrigin-RevId: 298927832 Change-Id: I393e73a0869fa51f8c7c90bbfbdc4c1fa8ba521e
This commit is contained in:
parent
f7d4c7ffd5
commit
f6793de3fe
@ -18,9 +18,13 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.distribute import tpu_strategy as tpu_lib
|
||||
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import remote
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import flags
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import tpu_strategy_util
|
||||
@ -32,6 +36,17 @@ flags.DEFINE_string("project", None, "Name of GCP project with TPU.")
|
||||
# GCP zone flag; get_tpu_strategy() forwards it to TPUClusterResolver as `zone`.
flags.DEFINE_string("zone", None, "Name of GCP zone with TPU.")
|
||||
|
||||
|
||||
def get_tpu_strategy():
  """Builds a TPUStrategy for the TPU selected by the command-line flags.

  A cluster resolver is created from the `tpu`, `zone` and `project` flags.
  The eager runtime is then connected to that cluster and the TPU system is
  initialized before the strategy is constructed.

  Returns:
    A `tpu_lib.TPUStrategy` constructed from the cluster resolver.
  """
  resolver_kwargs = dict(
      tpu=FLAGS.tpu,
      zone=FLAGS.zone,
      project=FLAGS.project,
  )
  cluster = tpu_cluster_resolver.TPUClusterResolver(**resolver_kwargs)

  # Connect first, then initialize: both steps operate on the same resolver.
  remote.connect_to_cluster(cluster)
  tpu_strategy_util.initialize_tpu_system(cluster)

  strategy = tpu_lib.TPUStrategy(cluster)
  return strategy
|
||||
|
||||
|
||||
class TpuStrategyTest(test.TestCase):
|
||||
|
||||
def test_multiple_initialize_system(self):
|
||||
@ -47,6 +62,33 @@ class TpuStrategyTest(test.TestCase):
|
||||
tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
self.assertRegex(str(mock_log.call_args), "already been initialized")
|
||||
|
||||
def test_recover_from_compilation_failures(self):
  """Checks that a strategy still works after a TPU compilation failure.

  The test first runs a function that is expected to raise
  `errors.InvalidArgumentError` with the message "TPU compilation failed".
  It then runs a second function on the same strategy and requires that it
  completes without raising.
  """
  strategy = get_tpu_strategy()

  @def_function.function
  def compilation_failure_run():

    def computation():
      # Running this op under the strategy is the step expected to fail
      # compilation (asserted below).
      # NOTE(review): presumably random_gamma is unsupported on TPU -- confirm.
      samples = random_ops.random_gamma([10], [0.5, 1.5])
      return samples

    return strategy.experimental_run_v2(computation)

  # Use `assertRaisesRegex`: `assertRaisesRegexp` is a deprecated unittest
  # alias, and the rest of this test case already uses the modern
  # `assertRegex` spelling.
  with self.assertRaisesRegex(errors.InvalidArgumentError,
                              "TPU compilation failed"):
    compilation_failure_run()

  @def_function.function
  def good_run():

    def computation():
      samples = random_ops.random_normal([10])
      return samples

    return strategy.experimental_run_v2(computation)

  # Must not raise: the same strategy has to remain usable after the failure.
  good_run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner when executed as a script.
  test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user