Switch legalization status type to match old bridge
* Switch to more appropriate/matching old bridge invalid arg for failed legalization; * Also enable some function jit tests which just have different error messages; PiperOrigin-RevId: 356283724 Change-Id: I2af765ad220e483d45737245427e321fa3b7fddb
This commit is contained in:
parent
4cf728b600
commit
0921a80d56
@ -360,7 +360,7 @@ Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
|
||||
|
||||
if (failed(tf2xla.run(module_op))) {
|
||||
return error_handler.Combine(
|
||||
errors::Internal("MLIR TF to XLA legalization failed"));
|
||||
errors::InvalidArgument("TF to XLA legalization failed"));
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(1))
|
||||
|
@ -131,8 +131,6 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||
self.assertAllClose([2, 3, 3, 4, 4], fn2(inputs, 1))
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns'
|
||||
' wrong status type')
|
||||
def testNestedCallUnsupportedOps(self):
|
||||
with ops.device('device:{}:0'.format(self.device)):
|
||||
|
||||
@ -146,12 +144,11 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
|
||||
func = def_function.function(fn2, jit_compile=False)
|
||||
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'not compilable'):
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError, 'legalization failed'
|
||||
if test_util.is_mlir_bridge_enabled() else 'not compilable'):
|
||||
func(inputs)
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns'
|
||||
' wrong status type')
|
||||
def testUnsupportedOps(self):
|
||||
if 'tpu' in self.device.lower():
|
||||
self.skipTest('XLA TPU supports tf.unique')
|
||||
@ -166,8 +163,9 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
|
||||
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||
self.assertAllClose([1, 2, 3], func(inputs))
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'not compilable'):
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError, 'legalization failed'
|
||||
if test_util.is_mlir_bridge_enabled() else 'not compilable'):
|
||||
xla_func(inputs)
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
||||
@ -200,8 +198,6 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
self.assertIn('def_function_xla_jit_test',
|
||||
g.experimental_get_compiler_ir(inputs, inputs)())
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
||||
'support stack traces')
|
||||
def testPythonStackTrace(self):
|
||||
if 'tpu' in self.device.lower():
|
||||
self.skipTest('XLA TPU supports tf.unique')
|
||||
@ -216,8 +212,6 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT2'):
|
||||
fn(inputs)
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
||||
'support stack traces')
|
||||
def testPythonStackTraceControlFlow(self):
|
||||
if 'tpu' in self.device.lower():
|
||||
self.skipTest('XLA TPU supports tf.unique')
|
||||
@ -240,8 +234,6 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError, r'\.y\[0\]'):
|
||||
f(constant_op.constant(100.0))
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
|
||||
'support stack traces')
|
||||
def testPythonStackTraceUncompiledWithinCompiled(self):
|
||||
if 'tpu' in self.device.lower():
|
||||
self.skipTest('XLA TPU supports tf.unique')
|
||||
@ -373,8 +365,6 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
c = C()
|
||||
self.assertAllClose([2, 3, 3, 4, 4], c.f1(inputs, 1))
|
||||
|
||||
@test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns '
|
||||
' wrong status type')
|
||||
def testMethodCompilationUnsupportedFunc(self):
|
||||
if 'tpu' in self.device.lower():
|
||||
self.skipTest('XLA TPU supports tf.unique')
|
||||
@ -389,8 +379,9 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
|
||||
inputs = constant_op.constant([1, 2, 2, 3, 3])
|
||||
c = C()
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'not compilable'):
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError, 'legalization failed'
|
||||
if test_util.is_mlir_bridge_enabled() else 'not compilable'):
|
||||
c.f1(inputs)
|
||||
|
||||
def testMustBeConstantPropagation(self):
|
||||
@ -900,7 +891,7 @@ class DefFunctionTest(xla_test.XLATestCase):
|
||||
return ta.concat() # EXPECTED_MESSAGE_OLD
|
||||
|
||||
if test_util.is_mlir_bridge_enabled():
|
||||
with self.assertRaisesRegex(errors.InternalError,
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
'EXPECTED_MESSAGE_NEW'):
|
||||
f()
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user