Add shape inference function for XlaScatter

PiperOrigin-RevId: 334603580
Change-Id: Idc193c32dd429cdf7f14a8496c6340c4e7b803b4
This commit is contained in:
A. Unique TensorFlower 2020-09-30 08:26:55 -07:00 committed by TensorFlower Gardener
parent 1c9964097d
commit fa75523767

View File

@ -762,7 +762,7 @@ REGISTER_OP("XlaScatter")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Output("output: T")
.SetShapeFn(UnchangedRank)
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Wraps the XLA Scatter operator documented at
https://www.tensorflow.org/xla/operation_semantics#scatter.