diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index 2f895b17219..f73d2b109a1 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -762,7 +762,7 @@ REGISTER_OP("XlaScatter") .Attr("T: numbertype") .Attr("Tindices: {int32, int64}") .Output("output: T") - .SetShapeFn(UnchangedRank) + .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( Wraps the XLA Scatter operator documented at https://www.tensorflow.org/xla/operation_semantics#scatter.