Make compilation error check generic.
TF2XLA and MLIR bridge return different error types when compilation fails. Makes the check more generic. Also renames experimental_run -> run in test names. PiperOrigin-RevId: 337935812 Change-Id: I453f79ddda4ae9f2a184936683e70969fb7ffedd
This commit is contained in:
parent
209ede00c5
commit
f0bfa71e14
@ -179,7 +179,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
||||
with ops.device("/device:TPU:0"):
|
||||
self.assertAllEqual(func(), 2.0)
|
||||
|
||||
def test_sequential_experimental_runs(self, enable_packed_var):
|
||||
def test_sequential_runs(self, enable_packed_var):
|
||||
resolver = get_tpu_cluster_resolver()
|
||||
remote.connect_to_cluster(resolver)
|
||||
topology = tpu_strategy_util.initialize_tpu_system(resolver)
|
||||
@ -254,8 +254,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
return strategy.run(computation)
|
||||
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"TPU compilation failed"):
|
||||
with self.assertRaises(errors.OpError):
|
||||
compilation_failure_run()
|
||||
|
||||
@def_function.function
|
||||
@ -476,7 +475,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
||||
self.assertAllEqual(expected_result, run(input_iterator))
|
||||
self.assertAllEqual((0.,), w.read_value())
|
||||
|
||||
def test_experimental_run_output_on_device(self, enable_packed_var):
|
||||
def test_run_output_on_device(self, enable_packed_var):
|
||||
strategy = get_tpu_strategy(enable_packed_var)
|
||||
|
||||
def computation(x):
|
||||
|
Loading…
Reference in New Issue
Block a user