Add shape inference function for XlaScatter
PiperOrigin-RevId: 334603580 Change-Id: Idc193c32dd429cdf7f14a8496c6340c4e7b803b4
This commit is contained in:
parent
1c9964097d
commit
fa75523767
@ -762,7 +762,7 @@ REGISTER_OP("XlaScatter")
|
|||||||
.Attr("T: numbertype")
|
.Attr("T: numbertype")
|
||||||
.Attr("Tindices: {int32, int64}")
|
.Attr("Tindices: {int32, int64}")
|
||||||
.Output("output: T")
|
.Output("output: T")
|
||||||
.SetShapeFn(UnchangedRank)
|
.SetShapeFn(shape_inference::UnchangedShape)
|
||||||
.Doc(R"doc(
|
.Doc(R"doc(
|
||||||
Wraps the XLA Scatter operator documented at
|
Wraps the XLA Scatter operator documented at
|
||||||
https://www.tensorflow.org/xla/operation_semantics#scatter.
|
https://www.tensorflow.org/xla/operation_semantics#scatter.
|
||||||
|
Loading…
Reference in New Issue
Block a user