Switch legalization status type to match old bridge

* Switch to more appropriate/matching old bridge invalid arg for failed legalization;
* Also enable some function jit tests which just have different error messages;

PiperOrigin-RevId: 356283724
Change-Id: I2af765ad220e483d45737245427e321fa3b7fddb
This commit is contained in:
Jacques Pienaar 2021-02-08 09:32:04 -08:00 committed by TensorFlower Gardener
parent 4cf728b600
commit 0921a80d56
2 changed files with 11 additions and 20 deletions

View File

@ -360,7 +360,7 @@ Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
if (failed(tf2xla.run(module_op))) {
return error_handler.Combine(
errors::Internal("MLIR TF to XLA legalization failed"));
errors::InvalidArgument("TF to XLA legalization failed"));
}
if (VLOG_IS_ON(1))

View File

@ -131,8 +131,6 @@ class DefFunctionTest(xla_test.XLATestCase):
inputs = constant_op.constant([1, 2, 2, 3, 3])
self.assertAllClose([2, 3, 3, 4, 4], fn2(inputs, 1))
@test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns'
' wrong status type')
def testNestedCallUnsupportedOps(self):
with ops.device('device:{}:0'.format(self.device)):
@ -146,12 +144,11 @@ class DefFunctionTest(xla_test.XLATestCase):
func = def_function.function(fn2, jit_compile=False)
inputs = constant_op.constant([1, 2, 2, 3, 3])
with self.assertRaisesRegex(errors.InvalidArgumentError,
'not compilable'):
with self.assertRaisesRegex(
errors.InvalidArgumentError, 'legalization failed'
if test_util.is_mlir_bridge_enabled() else 'not compilable'):
func(inputs)
@test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns'
' wrong status type')
def testUnsupportedOps(self):
if 'tpu' in self.device.lower():
self.skipTest('XLA TPU supports tf.unique')
@ -166,8 +163,9 @@ class DefFunctionTest(xla_test.XLATestCase):
inputs = constant_op.constant([1, 2, 2, 3, 3])
self.assertAllClose([1, 2, 3], func(inputs))
with self.assertRaisesRegex(errors.InvalidArgumentError,
'not compilable'):
with self.assertRaisesRegex(
errors.InvalidArgumentError, 'legalization failed'
if test_util.is_mlir_bridge_enabled() else 'not compilable'):
xla_func(inputs)
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
@ -200,8 +198,6 @@ class DefFunctionTest(xla_test.XLATestCase):
self.assertIn('def_function_xla_jit_test',
g.experimental_get_compiler_ir(inputs, inputs)())
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
'support stack traces')
def testPythonStackTrace(self):
if 'tpu' in self.device.lower():
self.skipTest('XLA TPU supports tf.unique')
@ -216,8 +212,6 @@ class DefFunctionTest(xla_test.XLATestCase):
with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT2'):
fn(inputs)
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
'support stack traces')
def testPythonStackTraceControlFlow(self):
if 'tpu' in self.device.lower():
self.skipTest('XLA TPU supports tf.unique')
@ -240,8 +234,6 @@ class DefFunctionTest(xla_test.XLATestCase):
with self.assertRaisesRegex(errors.InvalidArgumentError, r'\.y\[0\]'):
f(constant_op.constant(100.0))
@test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not'
'support stack traces')
def testPythonStackTraceUncompiledWithinCompiled(self):
if 'tpu' in self.device.lower():
self.skipTest('XLA TPU supports tf.unique')
@ -373,8 +365,6 @@ class DefFunctionTest(xla_test.XLATestCase):
c = C()
self.assertAllClose([2, 3, 3, 4, 4], c.f1(inputs, 1))
@test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns '
' wrong status type')
def testMethodCompilationUnsupportedFunc(self):
if 'tpu' in self.device.lower():
self.skipTest('XLA TPU supports tf.unique')
@ -389,8 +379,9 @@ class DefFunctionTest(xla_test.XLATestCase):
inputs = constant_op.constant([1, 2, 2, 3, 3])
c = C()
with self.assertRaisesRegex(errors.InvalidArgumentError,
'not compilable'):
with self.assertRaisesRegex(
errors.InvalidArgumentError, 'legalization failed'
if test_util.is_mlir_bridge_enabled() else 'not compilable'):
c.f1(inputs)
def testMustBeConstantPropagation(self):
@ -900,7 +891,7 @@ class DefFunctionTest(xla_test.XLATestCase):
return ta.concat() # EXPECTED_MESSAGE_OLD
if test_util.is_mlir_bridge_enabled():
with self.assertRaisesRegex(errors.InternalError,
with self.assertRaisesRegex(errors.InvalidArgumentError,
'EXPECTED_MESSAGE_NEW'):
f()
else: