Make compilation error check generic.

TF2XLA and MLIR bridge return different error types when compilation fails. Makes the check more generic.  Also renames experimental_run -> run in test names.

PiperOrigin-RevId: 337935812
Change-Id: I453f79ddda4ae9f2a184936683e70969fb7ffedd
This commit is contained in:
Ken Franko 2020-10-19 14:34:59 -07:00 committed by TensorFlower Gardener
parent 209ede00c5
commit f0bfa71e14

View File

@ -179,7 +179,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
with ops.device("/device:TPU:0"):
self.assertAllEqual(func(), 2.0)
def test_sequential_experimental_runs(self, enable_packed_var):
def test_sequential_runs(self, enable_packed_var):
resolver = get_tpu_cluster_resolver()
remote.connect_to_cluster(resolver)
topology = tpu_strategy_util.initialize_tpu_system(resolver)
@ -254,8 +254,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
return strategy.run(computation)
with self.assertRaisesRegex(errors.InvalidArgumentError,
"TPU compilation failed"):
with self.assertRaises(errors.OpError):
compilation_failure_run()
@def_function.function
@ -476,7 +475,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual(expected_result, run(input_iterator))
self.assertAllEqual((0.,), w.read_value())
def test_experimental_run_output_on_device(self, enable_packed_var):
def test_run_output_on_device(self, enable_packed_var):
strategy = get_tpu_strategy(enable_packed_var)
def computation(x):