From 242e920d7fe7876f3eefa65252c4c857247c360c Mon Sep 17 00:00:00 2001 From: Smit Hinsu <hinsu@google.com> Date: Fri, 13 Nov 2020 13:59:23 -0800 Subject: [PATCH] Enable fallback legalization for Erfinv and Ndtri ops PiperOrigin-RevId: 342333062 Change-Id: I5bd57860564db162274c59983adf00bbacde288b --- .../mlir/xla/tests/legalize-tf-with-tf2xla.mlir | 14 ++++++++++++++ .../mlir/xla/transforms/legalize_tf_with_tf2xla.cc | 2 ++ 2 files changed, 16 insertions(+) diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index ec61cc06001..3b8170cf894 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -307,6 +307,20 @@ func @set_dynamic_dimension_size(%input: tensor<4xf32>, %size: tensor<i32>) -> t return %0 : tensor<4xf32> } +// CHECK-LABEL: @erfinv +func @erfinv(%input: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NOT: tf.Erfinv + %0 = "tf.Erfinv"(%input) : (tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + +// CHECK-LABEL: @ndtri +func @ndtri(%input: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NOT: tf.Ndtri + %0 = "tf.Ndtri"(%input) : (tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 51c1a16b6f8..763b4acc154 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -132,6 +132,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get<TF::EluOp>(), TypeID::get<TF::EqualOp>(), TypeID::get<TF::ErfcOp>(), + TypeID::get<TF::ErfinvOp>(), TypeID::get<TF::ErfOp>(), TypeID::get<TF::Expm1Op>(), TypeID::get<TF::ExtractImagePatchesOp>(), @@ -181,6 +182,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get<TF::MirrorPadGradOp>(), TypeID::get<TF::MulOp>(), TypeID::get<TF::MultinomialOp>(), + TypeID::get<TF::NdtriOp>(), TypeID::get<TF::NegOp>(), TypeID::get<TF::NextAfterOp>(), TypeID::get<TF::NonMaxSuppressionV4Op>(),