Enable TridiagonalSolve and XlaSort op for fallback path lowering

Also, enable compiler/tests for reduce_ops_test

PiperOrigin-RevId: 338778787
Change-Id: I7db2ff2281f6404884fcf12204042a70b97a6782
This commit is contained in:
Smit Hinsu 2020-10-23 18:12:42 -07:00 committed by TensorFlower Gardener
parent e7b54fbda2
commit 62bc872eef
5 changed files with 21 additions and 7 deletions

View File

@ -232,6 +232,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::TensorScatterSubOp>(),
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
TypeID::get<TF::TransposeOp>(),
TypeID::get<TF::TridiagonalSolveOp>(),
TypeID::get<TF::TruncateDivOp>(),
TypeID::get<TF::TruncatedNormalOp>(),
TypeID::get<TF::TruncateModOp>(),
@ -247,7 +248,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::XlaKeyValueSortOp>(),
TypeID::get<TF::XlaPadOp>(),
TypeID::get<TF::Xlog1pyOp>(),
TypeID::get<TF::XlogyOp>()
TypeID::get<TF::XlogyOp>(),
TypeID::get<TF::XlaSortOp>()
};
// clang-format on

View File

@ -1087,6 +1087,7 @@ tf_xla_py_test(
name = "reduce_ops_test",
size = "medium",
srcs = ["reduce_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 5,
tags = [
@ -1546,6 +1547,7 @@ tf_xla_py_test(
name = "sort_ops_test",
size = "medium",
srcs = ["sort_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 1,
# Times out in fastbuild mode.
@ -1889,6 +1891,7 @@ tf_xla_py_test(
"cpu",
"gpu",
],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip

View File

@ -27,6 +27,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
@ -63,13 +64,16 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
self.assertAllClose(
result, np_reduce_fn(test_input, axis=1), rtol=rtol, atol=atol)
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
sess.run(out, {a: test_input, index: [-33]})
# MLIR bridge doesn't return the same error so it can't be matched
# directly.
if not test_util.is_mlir_bridge_enabled():
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
sess.run(out, {a: test_input, index: [-33]})
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
sess.run(out, {a: test_input, index: [2]})
with self.assertRaisesWithPredicateMatch(
errors_impl.InvalidArgumentError, 'Invalid reduction dim'):
sess.run(out, {a: test_input, index: [2]})
REAL_DATA = [
np.zeros(shape=(2, 0)),
@ -168,6 +172,7 @@ class ReduceOpsTest(xla_test.XLATestCase, parameterized.TestCase):
self._testReduction(math_ops.reduce_any, np.any, np.bool, self.BOOL_DATA,
index_dtype)
@test_util.disable_mlir_bridge('Error messages differ')
def testReduceSumWithDuplicateAxes(self, index_dtype):
with self.session() as sess:
with self.test_scope():

View File

@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
@ -127,6 +128,7 @@ class XlaSortOpTest(xla_test.XLATestCase):
topk, [x.astype(dtype)],
expected=[expected.astype(dtype), indices])
@test_util.disable_mlir_bridge("Support compare type in HLO Compare Op")
def testTopKZeros(self):
"""Tests that positive and negative zeros sort correctly."""
supported_types = set([dtypes.bfloat16.as_numpy_dtype, np.float32])

View File

@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients as gradient_ops
from tensorflow.python.ops import math_ops
@ -211,6 +212,7 @@ class TridiagonalSolveOpsTest(xla_test.XLATestCase):
# test2x2NotInvertible is skipped as runtime error not raised for now.
@test_util.disable_mlir_bridge("Error messages differ")
def testPartialPivotingRaises(self):
np.random.seed(0)
batch_size = 8