Whitelist InTopKV2, NextAfter and XlaKeyValueSort ops for the fallback path

Enabled relevant tests.

PiperOrigin-RevId: 335374607
Change-Id: I109c39459944648317c3a5274be4b5fe6c6e9586
This commit is contained in:
Smit Hinsu 2020-10-05 02:35:56 -07:00 committed by TensorFlower Gardener
parent 45654084ab
commit 07e0f88dd1
3 changed files with 5 additions and 1 deletions
tensorflow/compiler

View File

@ -151,6 +151,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::IgammaOp>(),
TypeID::get<TF::IgammacOp>(),
TypeID::get<TF::IgammaGradAOp>(),
TypeID::get<TF::InTopKV2Op>(),
TypeID::get<TF::InvertOp>(),
TypeID::get<TF::InvOp>(),
TypeID::get<TF::LRNOp>(),
@ -177,6 +178,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::MulOp>(),
TypeID::get<TF::MultinomialOp>(),
TypeID::get<TF::NegOp>(),
TypeID::get<TF::NextAfterOp>(),
TypeID::get<TF::NonMaxSuppressionV4Op>(),
TypeID::get<TF::NotEqualOp>(),
TypeID::get<TF::PadOp>(),
@ -241,6 +243,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::XlaDynamicSliceOp>(),
TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
TypeID::get<TF::XlaEinsumOp>(),
TypeID::get<TF::XlaKeyValueSortOp>(),
TypeID::get<TF::XlaPadOp>(),
TypeID::get<TF::Xlog1pyOp>(),
TypeID::get<TF::XlogyOp>()

View File

@ -1089,6 +1089,7 @@ tf_xla_py_test(
name = "reduce_ops_test",
size = "medium",
srcs = ["reduce_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 5,
tags = [
@ -1545,6 +1546,7 @@ tf_xla_py_test(
name = "sort_ops_test",
size = "medium",
srcs = ["sort_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 1,
# Times out in fastbuild mode.

View File

@ -474,7 +474,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36],
dtype=np.int64))
@test_util.disable_mlir_bridge("Enable tf.NextAfter Compilation")
def testNextAfter(self):
for dtype in self.numeric_types:
if dtype in [np.float32, np.float64]: