Fix TensorFlow ScatterNd op lowering to HLO
* Use TensorScatterAdd to correctly handle repeated indices * Support Complex typed operands Also, enabled complex typed operands for ZerosLikeOp. PiperOrigin-RevId: 338347731 Change-Id: Ieade5166c3d8e234bda2f090bc636dc6c98931b1
This commit is contained in:
parent
f75608af78
commit
a309b4ee31
@ -641,7 +641,7 @@ func @Reciprocal_complexf64(%arg0: tensor<*xcomplex<f64>>) -> tensor<*xcomplex<f
|
|||||||
// CHECK-LABEL: @ScatterNd
|
// CHECK-LABEL: @ScatterNd
|
||||||
func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
|
func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> {
|
||||||
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32>
|
// CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32>
|
||||||
// CHECK: "tf.TensorScatterUpdate"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
|
// CHECK: "tf.TensorScatterAdd"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32>
|
||||||
|
|
||||||
%shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32>
|
%shape = "tf.Const"() {value = dense<[8]> : tensor<1xi32>} : () -> tensor<1xi32>
|
||||||
%0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32>
|
%0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32>
|
||||||
|
@ -273,13 +273,14 @@ def CreateTFShapeOp : NativeCodeCall<
|
|||||||
|
|
||||||
// TODO(hinsu): Support inputs of TensorList types.
|
// TODO(hinsu): Support inputs of TensorList types.
|
||||||
def LowerZerosLikeOp :
|
def LowerZerosLikeOp :
|
||||||
Pat<(TF_ZerosLikeOp:$src_op TensorOf<[AnySignlessInteger, AnyFloat]>:$input),
|
Pat<(TF_ZerosLikeOp:$src_op
|
||||||
|
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$input),
|
||||||
(TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)),
|
(TF_BroadcastToOp (TF_ConstOp (GetScalarOfType<0> $input)),
|
||||||
(CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>;
|
(CreateTFShapeOp $src_op, $input, /*use 32bit*/ConstBoolAttrFalse))>;
|
||||||
|
|
||||||
def LowerScatterNdOp :
|
def LowerScatterNdOp :
|
||||||
Pat<(TF_ScatterNdOp $indices,
|
Pat<(TF_ScatterNdOp $indices,
|
||||||
TensorOf<[AnySignlessInteger, AnyFloat]>:$updates, $shape),
|
TensorOf<[AnyInteger, AnyFloat, AnyComplex]>:$updates, $shape),
|
||||||
(TF_TensorScatterUpdateOp
|
(TF_TensorScatterAddOp
|
||||||
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
|
(TF_FillOp $shape, (TF_ConstOp (GetScalarOfType<0> $updates))),
|
||||||
$indices, $updates)>;
|
$indices, $updates)>;
|
||||||
|
@ -228,6 +228,8 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) {
|
|||||||
TypeID::get<TF::StatelessTruncatedNormalOp>(),
|
TypeID::get<TF::StatelessTruncatedNormalOp>(),
|
||||||
TypeID::get<TF::SubOp>(),
|
TypeID::get<TF::SubOp>(),
|
||||||
TypeID::get<TF::TanOp>(),
|
TypeID::get<TF::TanOp>(),
|
||||||
|
TypeID::get<TF::TensorScatterAddOp>(),
|
||||||
|
TypeID::get<TF::TensorScatterSubOp>(),
|
||||||
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
|
TypeID::get<TF::TPUEmbeddingActivationsOp>(),
|
||||||
TypeID::get<TF::TransposeOp>(),
|
TypeID::get<TF::TransposeOp>(),
|
||||||
TypeID::get<TF::TruncateDivOp>(),
|
TypeID::get<TF::TruncateDivOp>(),
|
||||||
|
@ -1524,6 +1524,7 @@ tf_xla_py_test(
|
|||||||
name = "scatter_nd_op_test",
|
name = "scatter_nd_op_test",
|
||||||
size = "medium",
|
size = "medium",
|
||||||
srcs = ["scatter_nd_op_test.py"],
|
srcs = ["scatter_nd_op_test.py"],
|
||||||
|
enable_mlir_bridge = True,
|
||||||
python_version = "PY3",
|
python_version = "PY3",
|
||||||
tags = [
|
tags = [
|
||||||
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
|
||||||
|
@ -24,6 +24,7 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.compiler.tests import xla_test
|
from tensorflow.compiler.tests import xla_test
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
@ -161,6 +162,7 @@ class ScatterNdTest(xla_test.XLATestCase):
|
|||||||
expected = np.zeros([2, 2], dtype=np.int32)
|
expected = np.zeros([2, 2], dtype=np.int32)
|
||||||
self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2]))
|
self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2]))
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("Error messages differ")
|
||||||
def testRank3InvalidShape1(self):
|
def testRank3InvalidShape1(self):
|
||||||
indices = np.zeros([3, 2, 2], np.int32)
|
indices = np.zeros([3, 2, 2], np.int32)
|
||||||
updates = np.zeros([2, 2, 2], np.int32)
|
updates = np.zeros([2, 2, 2], np.int32)
|
||||||
@ -168,6 +170,7 @@ class ScatterNdTest(xla_test.XLATestCase):
|
|||||||
"Must have updates.shape"):
|
"Must have updates.shape"):
|
||||||
self._runScatterNd(indices, updates, [2, 2, 2])
|
self._runScatterNd(indices, updates, [2, 2, 2])
|
||||||
|
|
||||||
|
@test_util.disable_mlir_bridge("Error messages differ")
|
||||||
def testRank3InvalidShape2(self):
|
def testRank3InvalidShape2(self):
|
||||||
indices = np.zeros([2, 2, 1], np.int32)
|
indices = np.zeros([2, 2, 1], np.int32)
|
||||||
updates = np.zeros([2, 2], np.int32)
|
updates = np.zeros([2, 2], np.int32)
|
||||||
|
Loading…
Reference in New Issue
Block a user