diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 320793e0e12..e8a3dab4bed 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -257,7 +257,7 @@ class GatherOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(GatherOp); // The number of batch dimensions, as passed in the batch_dims attribute. - // It must be less than rank(indices). + // It must be less than or equal to rank(indices). int32 batch_dims_ = 0; }; diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index 4e051c6b186..849a2b4389f 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -182,7 +182,7 @@ class GatherOp : public OpKernel { private: // The number of batch dimensions, as passed in the batch_dims attribute. - // It must be less than rank(indices). + // It must be less than or equal to rank(indices). int32 batch_dims_ = 0; };