From 242e920d7fe7876f3eefa65252c4c857247c360c Mon Sep 17 00:00:00 2001
From: Smit Hinsu <hinsu@google.com>
Date: Fri, 13 Nov 2020 13:59:23 -0800
Subject: [PATCH] Enable fallback legalization for Erfinv and Ndtri ops

PiperOrigin-RevId: 342333062
Change-Id: I5bd57860564db162274c59983adf00bbacde288b
---
 .../mlir/xla/tests/legalize-tf-with-tf2xla.mlir    | 14 ++++++++++++++
 .../mlir/xla/transforms/legalize_tf_with_tf2xla.cc |  2 ++
 2 files changed, 16 insertions(+)

diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
index ec61cc06001..3b8170cf894 100644
--- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir
@@ -307,6 +307,20 @@ func @set_dynamic_dimension_size(%input: tensor<4xf32>, %size: tensor<i32>) -> t
   return %0 : tensor<4xf32>
 }
 
+// CHECK-LABEL: @erfinv
+func @erfinv(%input: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK-NOT: tf.Erfinv
+  %0 = "tf.Erfinv"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
+// CHECK-LABEL: @ndtri
+func @ndtri(%input: tensor<4xf32>) -> tensor<4xf32> {
+  // CHECK-NOT: tf.Ndtri
+  %0 = "tf.Ndtri"(%input) : (tensor<4xf32>) -> tensor<4xf32>
+  return %0 : tensor<4xf32>
+}
+
 
 // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
 // available but doesn't support this instance.
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 51c1a16b6f8..763b4acc154 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -132,6 +132,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
     TypeID::get<TF::EluOp>(),
     TypeID::get<TF::EqualOp>(),
     TypeID::get<TF::ErfcOp>(),
+    TypeID::get<TF::ErfinvOp>(),
     TypeID::get<TF::ErfOp>(),
     TypeID::get<TF::Expm1Op>(),
     TypeID::get<TF::ExtractImagePatchesOp>(),
@@ -181,6 +182,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
     TypeID::get<TF::MirrorPadGradOp>(),
     TypeID::get<TF::MulOp>(),
     TypeID::get<TF::MultinomialOp>(),
+    TypeID::get<TF::NdtriOp>(),
     TypeID::get<TF::NegOp>(),
     TypeID::get<TF::NextAfterOp>(),
     TypeID::get<TF::NonMaxSuppressionV4Op>(),