From a309b4ee31a36a884e2b3634b3c0bfbaea784c98 Mon Sep 17 00:00:00 2001 From: Smit Hinsu <hinsu@google.com> Date: Wed, 21 Oct 2020 15:02:08 -0700 Subject: [PATCH] Fix TensorFlow ScatterNd op lowering to HLO * Use TensorScatterAdd to correctly handle repeated indices * Support Complex typed operands Also, enabled complex typed operands for ZerosLikeOp. PiperOrigin-RevId: 338347731 Change-Id: Ieade5166c3d8e234bda2f090bc636dc6c98931b1 --- tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir | 2 +- tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td | 7 ++++--- .../mlir/xla/transforms/legalize_tf_with_tf2xla.cc | 2 ++ tensorflow/compiler/tests/BUILD | 1 + tensorflow/compiler/tests/scatter_nd_op_test.py | 3 +++ 5 files changed, 11 insertions(+), 4 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index 5230f05f25c..fcd2f2512fd 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -641,7 +641,7 @@ func @Reciprocal_complexf64(%arg0: tensor<*xcomplex<f64>>) -> tensor<*xcomplex<f // CHECK-LABEL: @ScatterNd func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> { // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32> - // CHECK: "tf.TensorScatterUpdate"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32> + // CHECK: "tf.TensorScatterAdd"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32> %shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32> %0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index bddc863ee60..fec4c20e98d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -273,13 +273,14 @@ def CreateTFShapeOp : NativeCodeCall< // TODO(hinsu): Support inputs of TensorList types. def LowerZerosLikeOp : - Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnySignlessInteger, AnyFloat]>:$input), + Pat<(TF_ZerosLikeOp:$src_op + TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input), (TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)), (CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>; def LowerScatterNdOp : Pat<(TF_ScatterNdOp $indices, - TensorOf<[AnySignlessInteger, AnyFloat]>:$updates, $shape), - (TF_TensorScatterUpdateOp + TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape), + (TF_TensorScatterAddOp (TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))), $indices, $updates)>; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index b392e91e22f..5098e581fd6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -228,6 +228,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get<TF::StatelessTruncatedNormalOp>(), TypeID::get<TF::SubOp>(), TypeID::get<TF::TanOp>(), + TypeID::get<TF::TensorScatterAddOp>(), + TypeID::get<TF::TensorScatterSubOp>(), TypeID::get<TF::TPUEmbeddingActivationsOp>(), TypeID::get<TF::TransposeOp>(), TypeID::get<TF::TruncateDivOp>(), diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index eb0cde57591..3f058fe0773 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -1524,6 +1524,7 @@ tf_xla_py_test( name = "scatter_nd_op_test", size = "medium", srcs = ["scatter_nd_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py index 3adb169e7f0..04531108b70 100644 --- a/tensorflow/compiler/tests/scatter_nd_op_test.py +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -24,6 +24,7 @@ import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import errors +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -161,6 +162,7 @@ class ScatterNdTest(xla_test.XLATestCase): expected = np.zeros([2, 2], dtype=np.int32) self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2])) + @test_util.disable_mlir_bridge("Error messages differ") def testRank3InvalidShape1(self): indices = np.zeros([3, 2, 2], np.int32) updates = np.zeros([2, 2, 2], np.int32) @@ -168,6 +170,7 @@ class ScatterNdTest(xla_test.XLATestCase): "Must have updates.shape"): self._runScatterNd(indices, updates, [2, 2, 2]) + @test_util.disable_mlir_bridge("Error messages differ") def testRank3InvalidShape2(self): indices = np.zeros([2, 2, 1], np.int32) updates = np.zeros([2, 2], np.int32)