Make compilation error check generic.
TF2XLA and MLIR bridge return different error types when compilation fails. Makes the check more generic. Also renames experimental_run -> run in test names. PiperOrigin-RevId: 337935812 Change-Id: I453f79ddda4ae9f2a184936683e70969fb7ffedd
This commit is contained in:
parent
209ede00c5
commit
f0bfa71e14
@ -179,7 +179,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
|||||||
with ops.device("/device:TPU:0"):
|
with ops.device("/device:TPU:0"):
|
||||||
self.assertAllEqual(func(), 2.0)
|
self.assertAllEqual(func(), 2.0)
|
||||||
|
|
||||||
def test_sequential_experimental_runs(self, enable_packed_var):
|
def test_sequential_runs(self, enable_packed_var):
|
||||||
resolver = get_tpu_cluster_resolver()
|
resolver = get_tpu_cluster_resolver()
|
||||||
remote.connect_to_cluster(resolver)
|
remote.connect_to_cluster(resolver)
|
||||||
topology = tpu_strategy_util.initialize_tpu_system(resolver)
|
topology = tpu_strategy_util.initialize_tpu_system(resolver)
|
||||||
@ -254,8 +254,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
|||||||
|
|
||||||
return strategy.run(computation)
|
return strategy.run(computation)
|
||||||
|
|
||||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
with self.assertRaises(errors.OpError):
|
||||||
"TPU compilation failed"):
|
|
||||||
compilation_failure_run()
|
compilation_failure_run()
|
||||||
|
|
||||||
@def_function.function
|
@def_function.function
|
||||||
@ -476,7 +475,7 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertAllEqual(expected_result, run(input_iterator))
|
self.assertAllEqual(expected_result, run(input_iterator))
|
||||||
self.assertAllEqual((0.,), w.read_value())
|
self.assertAllEqual((0.,), w.read_value())
|
||||||
|
|
||||||
def test_experimental_run_output_on_device(self, enable_packed_var):
|
def test_run_output_on_device(self, enable_packed_var):
|
||||||
strategy = get_tpu_strategy(enable_packed_var)
|
strategy = get_tpu_strategy(enable_packed_var)
|
||||||
|
|
||||||
def computation(x):
|
def computation(x):
|
||||||
|
Loading…
Reference in New Issue
Block a user