Fix TensorFlow ScatterNd op lowering to HLO

* Use TensorScatterAdd to correctly handle repeated indices
* Support Complex typed operands

Also, enabled complex typed operands for ZerosLikeOp.

PiperOrigin-RevId: 338347731
Change-Id: Ieade5166c3d8e234bda2f090bc636dc6c98931b1
This commit is contained in:
Smit Hinsu 2020-10-21 15:02:08 -07:00 committed by TensorFlower Gardener
parent f75608af78
commit a309b4ee31
5 changed files with 11 additions and 4 deletions

View File

@ -641,7 +641,7 @@ func @Reciprocal_complexf64(%arg0: tensor<*xcomplex<f64>>) -> tensor<*xcomplex<f
// CHECK-LABEL: @ScatterNd
func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32>
// CHECK: "tf.TensorScatterUpdate"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
// CHECK: "tf.TensorScatterAdd"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
%shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32>
%0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32>

View File

@ -273,13 +273,14 @@ def CreateTFShapeOp : NativeCodeCall<
// TODO(hinsu): Support inputs of TensorList types.
def LowerZerosLikeOp :
Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnySignlessInteger, AnyFloat]>:$input),
Pat<(TF_ZerosLikeOp:$src_op
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input),
(TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)),
(CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>;
def LowerScatterNdOp :
Pat<(TF_ScatterNdOp $indices,
TensorOf<[AnySignlessInteger, AnyFloat]>:$updates, $shape),
(TF_TensorScatterUpdateOp
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape),
(TF_TensorScatterAddOp
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
$indices, $updates)>;

View File

@ -228,6 +228,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
TypeID::get<TF::StatelessTruncatedNormalOp>(),
TypeID::get<TF::SubOp>(),
TypeID::get<TF::TanOp>(),
TypeID::get<TF::TensorScatterAddOp>(),
TypeID::get<TF::TensorScatterSubOp>(),
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
TypeID::get<TF::TransposeOp>(),
TypeID::get<TF::TruncateDivOp>(),

View File

@ -1524,6 +1524,7 @@ tf_xla_py_test(
name = "scatter_nd_op_test",
size = "medium",
srcs = ["scatter_nd_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip

View File

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@ -161,6 +162,7 @@ class ScatterNdTest(xla_test.XLATestCase):
expected = np.zeros([2, 2], dtype=np.int32)
self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2]))
@test_util.disable_mlir_bridge("Error messages differ")
def testRank3InvalidShape1(self):
indices = np.zeros([3, 2, 2], np.int32)
updates = np.zeros([2, 2, 2], np.int32)
@ -168,6 +170,7 @@ class ScatterNdTest(xla_test.XLATestCase):
"Must have updates.shape"):
self._runScatterNd(indices, updates, [2, 2, 2])
@test_util.disable_mlir_bridge("Error messages differ")
def testRank3InvalidShape2(self):
indices = np.zeros([2, 2, 1], np.int32)
updates = np.zeros([2, 2], np.int32)