diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 47738d83e05..cc17993be37 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -360,7 +360,7 @@ Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type, if (failed(tf2xla.run(module_op))) { return error_handler.Combine( - errors::Internal("MLIR TF to XLA legalization failed")); + errors::InvalidArgument("TF to XLA legalization failed")); } if (VLOG_IS_ON(1)) diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 88f0db3755b..80c13eaa435 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -131,8 +131,6 @@ class DefFunctionTest(xla_test.XLATestCase): inputs = constant_op.constant([1, 2, 2, 3, 3]) self.assertAllClose([2, 3, 3, 4, 4], fn2(inputs, 1)) - @test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns' - ' wrong status type') def testNestedCallUnsupportedOps(self): with ops.device('device:{}:0'.format(self.device)): @@ -146,12 +144,11 @@ class DefFunctionTest(xla_test.XLATestCase): func = def_function.function(fn2, jit_compile=False) inputs = constant_op.constant([1, 2, 2, 3, 3]) - with self.assertRaisesRegex(errors.InvalidArgumentError, - 'not compilable'): + with self.assertRaisesRegex( + errors.InvalidArgumentError, 'legalization failed' + if test_util.is_mlir_bridge_enabled() else 'not compilable'): func(inputs) - @test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns' - ' wrong status type') def testUnsupportedOps(self): if 'tpu' in self.device.lower(): self.skipTest('XLA TPU supports tf.unique') @@ -166,8 +163,9 @@ class DefFunctionTest(xla_test.XLATestCase): inputs = constant_op.constant([1, 2, 2, 3, 3]) self.assertAllClose([1, 2, 3], func(inputs)) - with self.assertRaisesRegex(errors.InvalidArgumentError, - 'not compilable'): + with self.assertRaisesRegex( + errors.InvalidArgumentError, 'legalization failed' + if test_util.is_mlir_bridge_enabled() else 'not compilable'): xla_func(inputs) @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not' @@ -200,8 +198,6 @@ class DefFunctionTest(xla_test.XLATestCase): self.assertIn('def_function_xla_jit_test', g.experimental_get_compiler_ir(inputs, inputs)()) - @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not' - 'support stack traces') def testPythonStackTrace(self): if 'tpu' in self.device.lower(): self.skipTest('XLA TPU supports tf.unique') @@ -216,8 +212,6 @@ class DefFunctionTest(xla_test.XLATestCase): with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT2'): fn(inputs) - @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not' - 'support stack traces') def testPythonStackTraceControlFlow(self): if 'tpu' in self.device.lower(): self.skipTest('XLA TPU supports tf.unique') @@ -240,8 +234,6 @@ class DefFunctionTest(xla_test.XLATestCase): with self.assertRaisesRegex(errors.InvalidArgumentError, r'\.y\[0\]'): f(constant_op.constant(100.0)) - @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not' - 'support stack traces') def testPythonStackTraceUncompiledWithinCompiled(self): if 'tpu' in self.device.lower(): self.skipTest('XLA TPU supports tf.unique') @@ -373,8 +365,6 @@ class DefFunctionTest(xla_test.XLATestCase): c = C() self.assertAllClose([2, 3, 3, 4, 4], c.f1(inputs, 1)) - @test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns ' - ' wrong status type') def testMethodCompilationUnsupportedFunc(self): if 'tpu' in self.device.lower(): self.skipTest('XLA TPU supports tf.unique') @@ -389,8 +379,9 @@ class DefFunctionTest(xla_test.XLATestCase): inputs = constant_op.constant([1, 2, 2, 3, 3]) c = C() - with self.assertRaisesRegex(errors.InvalidArgumentError, - 'not compilable'): + with self.assertRaisesRegex( + errors.InvalidArgumentError, 'legalization failed' + if test_util.is_mlir_bridge_enabled() else 'not compilable'): c.f1(inputs) def testMustBeConstantPropagation(self): @@ -900,7 +891,7 @@ class DefFunctionTest(xla_test.XLATestCase): return ta.concat() # EXPECTED_MESSAGE_OLD if test_util.is_mlir_bridge_enabled(): - with self.assertRaisesRegex(errors.InternalError, + with self.assertRaisesRegex(errors.InvalidArgumentError, 'EXPECTED_MESSAGE_NEW'): f() else: