Enable TridiagonalSolve and XlaSort op for fallback path lowering
Also, enable compiler/tests for reduce_ops_test PiperOrigin-RevId: 338778787 Change-Id: I7db2ff2281f6404884fcf12204042a70b97a6782
This commit is contained in:
parent
e7b54fbda2
commit
62bc872eef
@ -232,6 +232,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
|||||||
TypeID::get<TF::TensorScatterSubOp>(),
|
TypeID::get<TF::TensorScatterSubOp>(),
|
||||||
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
|
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
|
||||||
TypeID::get<TF::TransposeOp>(),
|
TypeID::get<TF::TransposeOp>(),
|
||||||
|
TypeID::get<TF::TridiagonalSolveOp>(),
|
||||||
TypeID::get<TF::TruncateDivOp>(),
|
TypeID::get<TF::TruncateDivOp>(),
|
||||||
TypeID::get<TF::TruncatedNormalOp>(),
|
TypeID::get<TF::TruncatedNormalOp>(),
|
||||||
TypeID::get<TF::TruncateModOp>(),
|
TypeID::get<TF::TruncateModOp>(),
|
||||||
@ -247,7 +248,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
|||||||
TypeID::get<TF::XlaKeyValueSortOp>(),
|
TypeID::get<TF::XlaKeyValueSortOp>(),
|
||||||
TypeID::get<TF::XlaPadOp>(),
|
TypeID::get<TF::XlaPadOp>(),
|
||||||
TypeID::get<TF::Xlog1pyOp>(),
|
TypeID::get<TF::Xlog1pyOp>(),
|
||||||
TypeID::get<TF::XlogyOp>()
|
TypeID::get<TF::XlogyOp>(),
|
||||||
|
TypeID::get<TF::XlaSortOp>()
|
||||||
};
|
};
|
||||||
// clang-format on
|
// clang-format on
|
||||||
|
|
||||||
|
@ -1087,6 +1087,7 @@ tf_xla_py_test(
|
|||||||
name = "reduce_ops_test",
|
name = "reduce_ops_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["reduce_ops_test.py"],
|
srcs = ["reduce_ops_test.py"],
|
||||||
|
enable_mlir_bridge = True,
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 5,
|
shard_count = 5,
|
||||||
tags = [
|
tags = [
|
||||||
@ -1546,6 +1547,7 @@ tf_xla_py_test(
|
|||||||
name = "sort_ops_test",
|
name = "sort_ops_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["sort_ops_test.py"],
|
srcs = ["sort_ops_test.py"],
|
||||||
|
enable_mlir_bridge = True,
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
shard_count = 1,
|
shard_count = 1,
|
||||||
# Times out in fastbuild mode.
|
# Times out in fastbuild mode.
|
||||||
@ -1889,6 +1891,7 @@ tf_xla_py_test(
|
|||||||
"cpu",
|
"cpu",
|
||||||
"gpu",
|
"gpu",
|
||||||
],
|
],
|
||||||
|
enable_mlir_bridge = True,
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
|
@ -27,6 +27,7 @@ import numpy as np
|
|||||||
from tensorflow.compiler.tests import xla_test
|
from tensorflow.compiler.tests import xla_test
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.platform import googletest
|
from tensorflow.python.platform import googletest
|
||||||
@ -63,13 +64,16 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
self.assertAllClose(
|
self.assertAllClose(
|
||||||
result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol)
|
result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol)
|
||||||
|
|
||||||
with self.assertRaisesWithPredicateMatch(
|
# MLIR bridge doesn't return the same error so it can't be matched
|
||||||
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
|
# directly.
|
||||||
sess.run(out, {a: test_input, index: [-33]})
|
if not test_util.is_mlir_bridge_enabled():
|
||||||
|
with self.assertRaisesWithPredicateMatch(
|
||||||
|
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
|
||||||
|
sess.run(out, {a: test_input, index: [-33]})
|
||||||
|
|
||||||
with self.assertRaisesWithPredicateMatch(
|
with self.assertRaisesWithPredicateMatch(
|
||||||
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
|
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
|
||||||
sess.run(out, {a: test_input, index: [2]})
|
sess.run(out, {a: test_input, index: [2]})
|
||||||
|
|
||||||
REAL_DATA = [
|
REAL_DATA = [
|
||||||
np.zeros(shape=(2, 0)),
|
np.zeros(shape=(2, 0)),
|
||||||
@ -168,6 +172,7 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
|
|||||||
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA,
|
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA,
|
||||||
index_dtype)
|
index_dtype)
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge('Error messages differ')
|
||||||
def testReduceSumWithDuplicateAxes(self, index_dtype):
|
def testReduceSumWithDuplicateAxes(self, index_dtype):
|
||||||
with self.session() as sess:
|
with self.session() as sess:
|
||||||
with self.test_scope():
|
with self.test_scope():
|
||||||
|
@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test
|
|||||||
from tensorflow.compiler.tf2xla.python import xla
|
from tensorflow.compiler.tf2xla.python import xla
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import nn_ops
|
from tensorflow.python.ops import nn_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -127,6 +128,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
|
|||||||
topk, [x.astype(dtype)],
|
topk, [x.astype(dtype)],
|
||||||
expected=[expected.astype(dtype), indices])
|
expected=[expected.astype(dtype), indices])
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("Support compare type in HLO Compare Op")
|
||||||
def testTopKZeros(self):
|
def testTopKZeros(self):
|
||||||
"""Tests that positive and negative zeros sort correctly."""
|
"""Tests that positive and negative zeros sort correctly."""
|
||||||
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
|
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])
|
||||||
|
@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test
|
|||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import errors_impl
|
from tensorflow.python.framework import errors_impl
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gradients as gradient_ops
|
from tensorflow.python.ops import gradients as gradient_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -211,6 +212,7 @@ class TridiagonalSolveOpsTest(xla_test.XLATestCase):
|
|||||||
|
|
||||||
# test2x2NotInvertible is skipped as runtime error not raised for now.
|
# test2x2NotInvertible is skipped as runtime error not raised for now.
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("Error messages differ")
|
||||||
def testPartialPivotingRaises(self):
|
def testPartialPivotingRaises(self):
|
||||||
np.random.seed(0)
|
np.random.seed(0)
|
||||||
batch_size = 8
|
batch_size = 8
|
||||||
|
Loading…
x
Reference in New Issue
Block a user