Whitelist InTopKV2, NextAfter and XlaKeyValueSort ops for the fallback path
Enabled relevant tests. PiperOrigin-RevId: 335374607 Change-Id: I109c39459944648317c3a5274be4b5fe6c6e9586
This commit is contained in:
parent
45654084ab
commit
07e0f88dd1
tensorflow/compiler
@ -151,6 +151,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::IgammaOp>(),
|
||||
TypeID::get<TF::IgammacOp>(),
|
||||
TypeID::get<TF::IgammaGradAOp>(),
|
||||
TypeID::get<TF::InTopKV2Op>(),
|
||||
TypeID::get<TF::InvertOp>(),
|
||||
TypeID::get<TF::InvOp>(),
|
||||
TypeID::get<TF::LRNOp>(),
|
||||
@ -177,6 +178,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::MulOp>(),
|
||||
TypeID::get<TF::MultinomialOp>(),
|
||||
TypeID::get<TF::NegOp>(),
|
||||
TypeID::get<TF::NextAfterOp>(),
|
||||
TypeID::get<TF::NonMaxSuppressionV4Op>(),
|
||||
TypeID::get<TF::NotEqualOp>(),
|
||||
TypeID::get<TF::PadOp>(),
|
||||
@ -241,6 +243,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::XlaDynamicSliceOp>(),
|
||||
TypeID::get<TF::XlaDynamicUpdateSliceOp>(),
|
||||
TypeID::get<TF::XlaEinsumOp>(),
|
||||
TypeID::get<TF::XlaKeyValueSortOp>(),
|
||||
TypeID::get<TF::XlaPadOp>(),
|
||||
TypeID::get<TF::Xlog1pyOp>(),
|
||||
TypeID::get<TF::XlogyOp>()
|
||||
|
@ -1089,6 +1089,7 @@ tf_xla_py_test(
|
||||
name = "reduce_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["reduce_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
shard_count = 5,
|
||||
tags = [
|
||||
@ -1545,6 +1546,7 @@ tf_xla_py_test(
|
||||
name = "sort_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["sort_ops_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
shard_count = 1,
|
||||
# Times out in fastbuild mode.
|
||||
|
@ -474,7 +474,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
expected=np.array([1 << 32, 1 << 36, 1 << 32, 1 << 36],
|
||||
dtype=np.int64))
|
||||
|
||||
@test_util.disable_mlir_bridge("Enable tf.NextAfter Compilation")
|
||||
def testNextAfter(self):
|
||||
for dtype in self.numeric_types:
|
||||
if dtype in [np.float32, np.float64]:
|
||||
|
Loading…
Reference in New Issue
Block a user