From a309b4ee31a36a884e2b3634b3c0bfbaea784c98 Mon Sep 17 00:00:00 2001
From: Smit Hinsu <hinsu@google.com>
Date: Wed, 21 Oct 2020 15:02:08 -0700
Subject: [PATCH] Fix TensorFlow ScatterNd op lowering to HLO

* Use TensorScatterAdd to correctly handle repeated indices
* Support Complex typed operands

Also, enabled complex typed operands for ZerosLikeOp.

PiperOrigin-RevId: 338347731
Change-Id: Ieade5166c3d8e234bda2f090bc636dc6c98931b1
---
 tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir    | 2 +-
 tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td | 7 ++++---
 .../mlir/xla/transforms/legalize_tf_with_tf2xla.cc         | 2 ++
 tensorflow/compiler/tests/BUILD                            | 1 +
 tensorflow/compiler/tests/scatter_nd_op_test.py            | 3 +++
 5 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
index 5230f05f25c..fcd2f2512fd 100644
--- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
+++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir
@@ -641,7 +641,7 @@ func @Reciprocal_complexf64(%arg0: tensor<*xcomplex<f64>>) -> tensor<*xcomplex<f
 // CHECK-LABEL: @ScatterNd
 func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
   // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32>
-  // CHECK: "tf.TensorScatterUpdate"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
+  // CHECK: "tf.TensorScatterAdd"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
 
   %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32>
   %0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32>
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
index bddc863ee60..fec4c20e98d 100644
--- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
+++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td
@@ -273,13 +273,14 @@ def CreateTFShapeOp : NativeCodeCall<
 
 // TODO(hinsu): Support inputs of TensorList types.
 def LowerZerosLikeOp :
-  Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnySignlessInteger, AnyFloat]>:$input),
+  Pat<(TF_ZerosLikeOp:$src_op
+       TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input),
       (TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)),
                         (CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>;
 
 def LowerScatterNdOp :
   Pat<(TF_ScatterNdOp $indices,
-       TensorOf<[AnySignlessInteger, AnyFloat]>:$updates, $shape),
-      (TF_TensorScatterUpdateOp
+       TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape),
+      (TF_TensorScatterAddOp
        (TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
        $indices, $updates)>;
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index b392e91e22f..5098e581fd6 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -228,6 +228,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
     TypeID::get<TF::StatelessTruncatedNormalOp>(),
     TypeID::get<TF::SubOp>(),
     TypeID::get<TF::TanOp>(),
+    TypeID::get<TF::TensorScatterAddOp>(),
+    TypeID::get<TF::TensorScatterSubOp>(),
     TypeID::get<TF::TPUEmbeddingActivationsOp>(),
     TypeID::get<TF::TransposeOp>(),
     TypeID::get<TF::TruncateDivOp>(),
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index eb0cde57591..3f058fe0773 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1524,6 +1524,7 @@ tf_xla_py_test(
     name = "scatter_nd_op_test",
     size = "medium",
     srcs = ["scatter_nd_op_test.py"],
+    enable_mlir_bridge = True,
     python_version = "PY3",
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py
index 3adb169e7f0..04531108b70 100644
--- a/tensorflow/compiler/tests/scatter_nd_op_test.py
+++ b/tensorflow/compiler/tests/scatter_nd_op_test.py
@@ -24,6 +24,7 @@ import numpy as np
 
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.platform import test
 
@@ -161,6 +162,7 @@ class ScatterNdTest(xla_test.XLATestCase):
     expected = np.zeros([2, 2], dtype=np.int32)
     self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2]))
 
+  @test_util.disable_mlir_bridge("Error messages differ")
   def testRank3InvalidShape1(self):
     indices = np.zeros([3, 2, 2], np.int32)
     updates = np.zeros([2, 2, 2], np.int32)
@@ -168,6 +170,7 @@ class ScatterNdTest(xla_test.XLATestCase):
                                              "Must have updates.shape"):
       self._runScatterNd(indices, updates, [2, 2, 2])
 
+  @test_util.disable_mlir_bridge("Error messages differ")
   def testRank3InvalidShape2(self):
     indices = np.zeros([2, 2, 1], np.int32)
     updates = np.zeros([2, 2], np.int32)