Make compilation error check generic.

TF2XLA and MLIR bridge return different error types when compilation fails. Makes the check more generic.  Also renames experimental_run -> run in test names.

PiperOrigin-RevId: 337935812
Change-Id: I453f79ddda4ae9f2a184936683e70969fb7ffedd
This commit is contained in:
Ken Franko 2020-10-19 14:34:59 -07:00 committed by TensorFlower Gardener
parent 209ede00c5
commit f0bfa71e14

View File

@ -179,7 +179,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
with ops.device("/device:TPU:0"): with ops.device("/device:TPU:0"):
self.assertAllEqual(func(), 2.0) self.assertAllEqual(func(), 2.0)
def test_sequential_experimental_runs(self, enable_packed_var): def test_sequential_runs(self, enable_packed_var):
resolver = get_tpu_cluster_resolver() resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver) remote.connect_to_cluster(resolver)
topology = tpu_strategy_util.initialize_tpu_system(resolver) topology = tpu_strategy_util.initialize_tpu_system(resolver)
@ -254,8 +254,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
return strategy.run(computation) return strategy.run(computation)
with self.assertRaisesRegex(errors.InvalidArgumentError, with self.assertRaises(errors.OpError):
"TPU compilation failed"):
compilation_failure_run() compilation_failure_run()
@def_function.function @def_function.function
@ -476,7 +475,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(expected_result, run(input_iterator)) self.assertAllEqual(expected_result, run(input_iterator))
self.assertAllEqual((0.,), w.read_value()) self.assertAllEqual((0.,), w.read_value())
def test_experimental_run_output_on_device(self, enable_packed_var): def test_run_output_on_device(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var) strategy = get_tpu_strategy(enable_packed_var)
def computation(x): def computation(x):