Enable TridiagonalSolve and XlaSort op for fallback path lowering
Also, enable compiler/tests for reduce_ops_test PiperOrigin-RevId: 338778787 Change-Id: I7db2ff2281f6404884fcf12204042a70b97a6782
This commit is contained in:
parent
e7b54fbda2
commit
62bc872eef
@ -232,6 +232,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::TensorScatterSubOp>(),
|
||||
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
|
||||
TypeID::get<TF::TransposeOp>(),
|
||||
TypeID::get<TF::TridiagonalSolveOp>(),
|
||||
TypeID::get<TF::TruncateDivOp>(),
|
||||
TypeID::get<TF::TruncatedNormalOp>(),
|
||||
TypeID::get<TF::TruncateModOp>(),
|
||||
@ -247,7 +248,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::XlaKeyValueSortOp>(),
|
||||
TypeID::get<TF::XlaPadOp>(),
|
||||
TypeID::get<TF::Xlog1pyOp>(),
|
||||
TypeID::get<TF::XlogyOp>()
|
||||
TypeID::get<TF::XlogyOp>(),
|
||||
TypeID::get<TF::XlaSortOp>()
|
||||
};
|
||||
// clang-format on
|
||||
|
||||
|
@ -1087,6 +1087,7 @@ tf_xla_py_test(
|
||||
name = "reduce_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["reduce_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
@ -1546,6 +1547,7 @@ tf_xla_py_test(
|
||||
name = "sort_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["sort_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
shard_count = 1,
|
||||
# Times out in fastbuild mode.
|
||||
@ -1889,6 +1891,7 @@ tf_xla_py_test(
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
|
@ -27,6 +27,7 @@ import numpy as np
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
@ -63,13 +64,16 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
self.assertAllClose(
|
||||
result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol)
|
||||
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
|
||||
sess.run(out, {a: test_input, index: [-33]})
|
||||
# MLIR bridge doesn't return the same error so it can't be matched
|
||||
# directly.
|
||||
if not test_util.is_mlir_bridge_enabled():
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
|
||||
sess.run(out, {a: test_input, index: [-33]})
|
||||
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
|
||||
sess.run(out, {a: test_input, index: [2]})
|
||||
with self.assertRaisesWithPredicateMatch(
|
||||
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
|
||||
sess.run(out, {a: test_input, index: [2]})
|
||||
|
||||
REAL_DATA = [
|
||||
np.zeros(shape=(2, 0)),
|
||||
@ -168,6 +172,7 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA,
|
||||
index_dtype)
|
||||
|
||||
@test_util.disable_mlir_bridge('Error messages differ')
|
||||
def testReduceSumWithDuplicateAxes(self, index_dtype):
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
|
@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.compiler.tf2xla.python import xla
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -127,6 +128,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
|
||||
topk, [x.astype(dtype)],
|
||||
expected=[expected.astype(dtype), indices])
|
||||
|
||||
@test_util.disable_mlir_bridge("Support compare type in HLO Compare Op")
|
||||
def testTopKZeros(self):
|
||||
"""Tests that positive and negative zeros sort correctly."""
|
||||
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
|
||||
|
@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradients as gradient_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -211,6 +212,7 @@ class TridiagonalSolveOpsTest(xla_test.XLATestCase):
|
||||
|
||||
# test2x2NotInvertible is skipped as runtime error not raised for now.
|
||||
|
||||
@test_util.disable_mlir_bridge("Error messages differ")
|
||||
def testPartialPivotingRaises(self):
|
||||
np.random.seed(0)
|
||||
batch_size = 8
|
||||
|
Loading…
x
Reference in New Issue
Block a user