Fix data race in GatherOp::Compute().

Don't modify a member variable (batch_dims_) inside Compute().

PiperOrigin-RevId: 347665956
Change-Id: I77dcf4be15e3961d2808b2d0ca323ab1a8d7a3a8
This commit is contained in:
A. Unique TensorFlower 2020-12-15 12:12:38 -08:00 committed by TensorFlower Gardener
parent 2dd9bc663e
commit e54e7a08c5

View File

@ -88,29 +88,31 @@ class GatherOp : public OpKernel {
axis = params.dims() + axis;
}
if (batch_dims_ != 0) {
OP_REQUIRES(
c, batch_dims_ >= -indices.dims() && batch_dims_ <= indices.dims(),
errors::InvalidArgument("Expected batch_dims in the range [",
-indices.dims(), ", ", indices.dims(),
"], but got ", batch_dims_));
// Modify only a local copy of batch_dims_.
int32 batch_dims = batch_dims_;
if (batch_dims != 0) {
OP_REQUIRES(c,
batch_dims >= -indices.dims() && batch_dims <= indices.dims(),
errors::InvalidArgument("Expected batch_dims in the range [",
-indices.dims(), ", ", indices.dims(),
"], but got ", batch_dims));
if (batch_dims_ < 0) {
batch_dims_ = indices.dims() + batch_dims_;
if (batch_dims < 0) {
batch_dims = indices.dims() + batch_dims;
}
if (!axis_is_set) axis = batch_dims_;
if (!axis_is_set) axis = batch_dims;
OP_REQUIRES(c, batch_dims_ < params.dims(),
errors::InvalidArgument("batch_dims (", batch_dims_,
OP_REQUIRES(c, batch_dims < params.dims(),
errors::InvalidArgument("batch_dims (", batch_dims,
") must be less than rank(params) (",
params.dims(), ")."));
OP_REQUIRES(c, axis >= batch_dims_,
errors::InvalidArgument("batch_dims (", batch_dims_,
OP_REQUIRES(c, axis >= batch_dims,
errors::InvalidArgument("batch_dims (", batch_dims,
") must be less than or equal to ",
"axis (", axis, ")."));
for (int i = 0; i < batch_dims_; ++i) {
for (int i = 0; i < batch_dims; ++i) {
OP_REQUIRES(c, params.dim_size(i) == indices.dim_size(i),
errors::InvalidArgument(
"params.shape[", i, "]: ", params.dim_size(i),
@ -136,15 +138,15 @@ class GatherOp : public OpKernel {
int64 outer_size = 1;
int64 inner_size = 1;
for (int i = 0; i < batch_dims_; ++i) {
for (int i = 0; i < batch_dims; ++i) {
result_shape.AddDim(params.dim_size(i));
batch_size *= params.dim_size(i);
}
for (int i = batch_dims_; i < axis; ++i) {
for (int i = batch_dims; i < axis; ++i) {
result_shape.AddDim(params.dim_size(i));
outer_size *= params.dim_size(i);
}
for (int i = batch_dims_; i < indices.dims(); ++i) {
for (int i = batch_dims; i < indices.dims(); ++i) {
result_shape.AddDim(indices.dim_size(i));
}
for (int i = axis + 1; i < params.dims(); ++i) {
@ -159,7 +161,7 @@ class GatherOp : public OpKernel {
int64 bad_i = -1;
auto indices_flat = indices.flat<Index>();
if (batch_dims_ > 0) {
if (batch_dims > 0) {
auto params_flat = params.shaped<T, 4>(
{batch_size, outer_size, gather_dim_size, inner_size});
auto out_flat = out->shaped<T, 4>(