Fix TensorFlow ScatterNd op lowering to HLO
* Use TensorScatterAdd to correctly handle repeated indices * Support Complex typed operands Also, enabled complex typed operands for ZerosLikeOp. PiperOrigin-RevId: 338347731 Change-Id: Ieade5166c3d8e234bda2f090bc636dc6c98931b1
This commit is contained in:
parent
f75608af78
commit
a309b4ee31
@ -641,7 +641,7 @@ func @Reciprocal_complexf64(%arg0: tensor<*xcomplex<f64>>) -> tensor<*xcomplex<f
|
||||
// CHECK-LABEL: @ScatterNd
|
||||
func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
|
||||
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32>
|
||||
// CHECK: "tf.TensorScatterUpdate"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
|
||||
// CHECK: "tf.TensorScatterAdd"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
|
||||
|
||||
%shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||
%0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32>
|
||||
|
@ -273,13 +273,14 @@ def CreateTFShapeOp : NativeCodeCall<
|
||||
|
||||
// TODO(hinsu): Support inputs of TensorList types.
|
||||
def LowerZerosLikeOp :
|
||||
Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnySignlessInteger, AnyFloat]>:$input),
|
||||
Pat<(TF_ZerosLikeOp:$src_op
|
||||
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input),
|
||||
(TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)),
|
||||
(CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>;
|
||||
|
||||
def LowerScatterNdOp :
|
||||
Pat<(TF_ScatterNdOp $indices,
|
||||
TensorOf<[AnySignlessInteger, AnyFloat]>:$updates, $shape),
|
||||
(TF_TensorScatterUpdateOp
|
||||
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape),
|
||||
(TF_TensorScatterAddOp
|
||||
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
|
||||
$indices, $updates)>;
|
||||
|
@ -228,6 +228,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
||||
TypeID::get<TF::StatelessTruncatedNormalOp>(),
|
||||
TypeID::get<TF::SubOp>(),
|
||||
TypeID::get<TF::TanOp>(),
|
||||
TypeID::get<TF::TensorScatterAddOp>(),
|
||||
TypeID::get<TF::TensorScatterSubOp>(),
|
||||
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
|
||||
TypeID::get<TF::TransposeOp>(),
|
||||
TypeID::get<TF::TruncateDivOp>(),
|
||||
|
@ -1524,6 +1524,7 @@ tf_xla_py_test(
|
||||
name = "scatter_nd_op_test",
|
||||
size = "medium",
|
||||
srcs = ["scatter_nd_op_test.py"],
|
||||
enable_mlir_bridge = True,
|
||||
python_version = "PY3",
|
||||
tags = [
|
||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||
|
@ -24,6 +24,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.compiler.tests import xla_test
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
@ -161,6 +162,7 @@ class ScatterNdTest(xla_test.XLATestCase):
|
||||
expected = np.zeros([2, 2], dtype=np.int32)
|
||||
self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2]))
|
||||
|
||||
@test_util.disable_mlir_bridge("Error messages differ")
|
||||
def testRank3InvalidShape1(self):
|
||||
indices = np.zeros([3, 2, 2], np.int32)
|
||||
updates = np.zeros([2, 2, 2], np.int32)
|
||||
@ -168,6 +170,7 @@ class ScatterNdTest(xla_test.XLATestCase):
|
||||
"Must have updates.shape"):
|
||||
self._runScatterNd(indices, updates, [2, 2, 2])
|
||||
|
||||
@test_util.disable_mlir_bridge("Error messages differ")
|
||||
def testRank3InvalidShape2(self):
|
||||
indices = np.zeros([2, 2, 1], np.int32)
|
||||
updates = np.zeros([2, 2], np.int32)
|
||||
|
Loading…
Reference in New Issue
Block a user