From 62bc872eef83fe8300823bdbd60c716ad58fecef Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Fri, 23 Oct 2020 18:12:42 -0700 Subject: [PATCH] Enable TridiagonalSolve and XlaSort op for fallback path lowering Also, enable compiler/tests for reduce_ops_test PiperOrigin-RevId: 338778787 Change-Id: I7db2ff2281f6404884fcf12204042a70b97a6782 --- .../xla/transforms/legalize_tf_with_tf2xla.cc | 4 +++- tensorflow/compiler/tests/BUILD | 3 +++ tensorflow/compiler/tests/reduce_ops_test.py | 17 +++++++++++------ tensorflow/compiler/tests/sort_ops_test.py | 2 ++ .../tests/tridiagonal_solve_ops_test.py | 2 ++ 5 files changed, 21 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 5098e581fd6..e87e86ac45a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -232,6 +232,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -247,7 +248,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), - TypeID::get() + TypeID::get(), + TypeID::get() }; // clang-format on diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 058891721db..0cbb5698f4a 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1087,6 +1087,7 @@ tf_xla_py_test( name = "reduce_ops_test", size = "medium", srcs = ["reduce_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1546,6 +1547,7 @@ tf_xla_py_test( name = "sort_ops_test", size = "medium", srcs = ["sort_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 1, # Times out in fastbuild mode. @@ -1889,6 +1891,7 @@ tf_xla_py_test( "cpu", "gpu", ], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tests/reduce_ops_test.py b/tensorflow/compiler/tests/reduce_ops_test.py index eb46c536e07..b8909608823 100644 --- a/tensorflow/compiler/tests/reduce_ops_test.py +++ b/tensorflow/compiler/tests/reduce_ops_test.py @@ -27,6 +27,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest @@ -63,13 +64,16 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): self.assertAllClose( result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol) - with self.assertRaisesWithPredicateMatch( - errors_impl.InvalidArgumentError, 'Invalid reduction dim'): - sess.run(out, {a: test_input, index: [-33]}) + # MLIR bridge doesn't return the same error so it can't be matched + # directly. + if not test_util.is_mlir_bridge_enabled(): + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, 'Invalid reduction dim'): + sess.run(out, {a: test_input, index: [-33]}) - with self.assertRaisesWithPredicateMatch( - errors_impl.InvalidArgumentError, 'Invalid reduction dim'): - sess.run(out, {a: test_input, index: [2]}) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, 'Invalid reduction dim'): + sess.run(out, {a: test_input, index: [2]}) REAL_DATA = [ np.zeros(shape=(2, 0)), @@ -168,6 +172,7 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase): self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA, index_dtype) + @test_util.disable_mlir_bridge('Error messages differ') def testReduceSumWithDuplicateAxes(self, index_dtype): with self.session() as sess: with self.test_scope(): diff --git a/tensorflow/compiler/tests/sort_ops_test.py b/tensorflow/compiler/tests/sort_ops_test.py index 838718aa1e3..847f1aa8cda 100644 --- a/tensorflow/compiler/tests/sort_ops_test.py +++ b/tensorflow/compiler/tests/sort_ops_test.py @@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tf2xla.python import xla from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops from tensorflow.python.platform import test @@ -127,6 +128,7 @@ class XlaSortOpTest(xla_test.XLATestCase): topk, [x.astype(dtype)], expected=[expected.astype(dtype), indices]) + @test_util.disable_mlir_bridge("Support compare type in HLO Compare Op") def testTopKZeros(self): """Tests that positive and negative zeros sort correctly.""" supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32]) diff --git a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py index e462211e5dd..ca50916dcca 100644 --- a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py +++ b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py @@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients as gradient_ops from tensorflow.python.ops import math_ops @@ -211,6 +212,7 @@ class TridiagonalSolveOpsTest(xla_test.XLATestCase): # test2x2NotInvertible is skipped as runtime error not raised for now. + @test_util.disable_mlir_bridge("Error messages differ") def testPartialPivotingRaises(self): np.random.seed(0) batch_size = 8