Fix data race in GatherOp::Compute().

Don't modify a member variable (batch_dims_) inside Compute().

PiperOrigin-RevId: 347665956
Change-Id: I77dcf4be15e3961d2808b2d0ca323ab1a8d7a3a8
This commit is contained in:
A. Unique TensorFlower 2020-12-15 12:12:38 -08:00 committed by TensorFlower Gardener
parent 2dd9bc663e
commit e54e7a08c5

View File

@ -88,29 +88,31 @@ class GatherOp : public OpKernel {
axis = params.dims() + axis; axis = params.dims() + axis;
} }
if (batch_dims_ != 0) { // Modify only a local copy of batch_dims_.
OP_REQUIRES( int32 batch_dims = batch_dims_;
c, batch_dims_ >= -indices.dims() && batch_dims_ <= indices.dims(), if (batch_dims != 0) {
OP_REQUIRES(c,
batch_dims >= -indices.dims() && batch_dims <= indices.dims(),
errors::InvalidArgument("Expected batch_dims in the range [", errors::InvalidArgument("Expected batch_dims in the range [",
-indices.dims(), ", ", indices.dims(), -indices.dims(), ", ", indices.dims(),
"], but got ", batch_dims_)); "], but got ", batch_dims));
if (batch_dims_ < 0) { if (batch_dims < 0) {
batch_dims_ = indices.dims() + batch_dims_; batch_dims = indices.dims() + batch_dims;
} }
if (!axis_is_set) axis = batch_dims_; if (!axis_is_set) axis = batch_dims;
OP_REQUIRES(c, batch_dims_ < params.dims(), OP_REQUIRES(c, batch_dims < params.dims(),
errors::InvalidArgument("batch_dims (", batch_dims_, errors::InvalidArgument("batch_dims (", batch_dims,
") must be less than rank(params) (", ") must be less than rank(params) (",
params.dims(), ").")); params.dims(), ")."));
OP_REQUIRES(c, axis >= batch_dims_, OP_REQUIRES(c, axis >= batch_dims,
errors::InvalidArgument("batch_dims (", batch_dims_, errors::InvalidArgument("batch_dims (", batch_dims,
") must be less than or equal to ", ") must be less than or equal to ",
"axis (", axis, ").")); "axis (", axis, ")."));
for (int i = 0; i < batch_dims_; ++i) { for (int i = 0; i < batch_dims; ++i) {
OP_REQUIRES(c, params.dim_size(i) == indices.dim_size(i), OP_REQUIRES(c, params.dim_size(i) == indices.dim_size(i),
errors::InvalidArgument( errors::InvalidArgument(
"params.shape[", i, "]: ", params.dim_size(i), "params.shape[", i, "]: ", params.dim_size(i),
@ -136,15 +138,15 @@ class GatherOp : public OpKernel {
int64 outer_size = 1; int64 outer_size = 1;
int64 inner_size = 1; int64 inner_size = 1;
for (int i = 0; i < batch_dims_; ++i) { for (int i = 0; i < batch_dims; ++i) {
result_shape.AddDim(params.dim_size(i)); result_shape.AddDim(params.dim_size(i));
batch_size *= params.dim_size(i); batch_size *= params.dim_size(i);
} }
for (int i = batch_dims_; i < axis; ++i) { for (int i = batch_dims; i < axis; ++i) {
result_shape.AddDim(params.dim_size(i)); result_shape.AddDim(params.dim_size(i));
outer_size *= params.dim_size(i); outer_size *= params.dim_size(i);
} }
for (int i = batch_dims_; i < indices.dims(); ++i) { for (int i = batch_dims; i < indices.dims(); ++i) {
result_shape.AddDim(indices.dim_size(i)); result_shape.AddDim(indices.dim_size(i));
} }
for (int i = axis + 1; i < params.dims(); ++i) { for (int i = axis + 1; i < params.dims(); ++i) {
@ -159,7 +161,7 @@ class GatherOp : public OpKernel {
int64 bad_i = -1; int64 bad_i = -1;
auto indices_flat = indices.flat<Index>(); auto indices_flat = indices.flat<Index>();
if (batch_dims_ > 0) { if (batch_dims > 0) {
auto params_flat = params.shaped<T, 4>( auto params_flat = params.shaped<T, 4>(
{batch_size, outer_size, gather_dim_size, inner_size}); {batch_size, outer_size, gather_dim_size, inner_size});
auto out_flat = out->shaped<T, 4>( auto out_flat = out->shaped<T, 4>(