Fix data race in GatherOp::Compute().
Don't modify a member variable (batch_dims_) inside Compute(). PiperOrigin-RevId: 347665956 Change-Id: I77dcf4be15e3961d2808b2d0ca323ab1a8d7a3a8
This commit is contained in:
parent
2dd9bc663e
commit
e54e7a08c5
@ -88,29 +88,31 @@ class GatherOp : public OpKernel {
|
|||||||
axis = params.dims() + axis;
|
axis = params.dims() + axis;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (batch_dims_ != 0) {
|
// Modify only a local copy of batch_dims_.
|
||||||
OP_REQUIRES(
|
int32 batch_dims = batch_dims_;
|
||||||
c, batch_dims_ >= -indices.dims() && batch_dims_ <= indices.dims(),
|
if (batch_dims != 0) {
|
||||||
|
OP_REQUIRES(c,
|
||||||
|
batch_dims >= -indices.dims() && batch_dims <= indices.dims(),
|
||||||
errors::InvalidArgument("Expected batch_dims in the range [",
|
errors::InvalidArgument("Expected batch_dims in the range [",
|
||||||
-indices.dims(), ", ", indices.dims(),
|
-indices.dims(), ", ", indices.dims(),
|
||||||
"], but got ", batch_dims_));
|
"], but got ", batch_dims));
|
||||||
|
|
||||||
if (batch_dims_ < 0) {
|
if (batch_dims < 0) {
|
||||||
batch_dims_ = indices.dims() + batch_dims_;
|
batch_dims = indices.dims() + batch_dims;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!axis_is_set) axis = batch_dims_;
|
if (!axis_is_set) axis = batch_dims;
|
||||||
|
|
||||||
OP_REQUIRES(c, batch_dims_ < params.dims(),
|
OP_REQUIRES(c, batch_dims < params.dims(),
|
||||||
errors::InvalidArgument("batch_dims (", batch_dims_,
|
errors::InvalidArgument("batch_dims (", batch_dims,
|
||||||
") must be less than rank(params) (",
|
") must be less than rank(params) (",
|
||||||
params.dims(), ")."));
|
params.dims(), ")."));
|
||||||
|
|
||||||
OP_REQUIRES(c, axis >= batch_dims_,
|
OP_REQUIRES(c, axis >= batch_dims,
|
||||||
errors::InvalidArgument("batch_dims (", batch_dims_,
|
errors::InvalidArgument("batch_dims (", batch_dims,
|
||||||
") must be less than or equal to ",
|
") must be less than or equal to ",
|
||||||
"axis (", axis, ")."));
|
"axis (", axis, ")."));
|
||||||
for (int i = 0; i < batch_dims_; ++i) {
|
for (int i = 0; i < batch_dims; ++i) {
|
||||||
OP_REQUIRES(c, params.dim_size(i) == indices.dim_size(i),
|
OP_REQUIRES(c, params.dim_size(i) == indices.dim_size(i),
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"params.shape[", i, "]: ", params.dim_size(i),
|
"params.shape[", i, "]: ", params.dim_size(i),
|
||||||
@ -136,15 +138,15 @@ class GatherOp : public OpKernel {
|
|||||||
int64 outer_size = 1;
|
int64 outer_size = 1;
|
||||||
int64 inner_size = 1;
|
int64 inner_size = 1;
|
||||||
|
|
||||||
for (int i = 0; i < batch_dims_; ++i) {
|
for (int i = 0; i < batch_dims; ++i) {
|
||||||
result_shape.AddDim(params.dim_size(i));
|
result_shape.AddDim(params.dim_size(i));
|
||||||
batch_size *= params.dim_size(i);
|
batch_size *= params.dim_size(i);
|
||||||
}
|
}
|
||||||
for (int i = batch_dims_; i < axis; ++i) {
|
for (int i = batch_dims; i < axis; ++i) {
|
||||||
result_shape.AddDim(params.dim_size(i));
|
result_shape.AddDim(params.dim_size(i));
|
||||||
outer_size *= params.dim_size(i);
|
outer_size *= params.dim_size(i);
|
||||||
}
|
}
|
||||||
for (int i = batch_dims_; i < indices.dims(); ++i) {
|
for (int i = batch_dims; i < indices.dims(); ++i) {
|
||||||
result_shape.AddDim(indices.dim_size(i));
|
result_shape.AddDim(indices.dim_size(i));
|
||||||
}
|
}
|
||||||
for (int i = axis + 1; i < params.dims(); ++i) {
|
for (int i = axis + 1; i < params.dims(); ++i) {
|
||||||
@ -159,7 +161,7 @@ class GatherOp : public OpKernel {
|
|||||||
|
|
||||||
int64 bad_i = -1;
|
int64 bad_i = -1;
|
||||||
auto indices_flat = indices.flat<Index>();
|
auto indices_flat = indices.flat<Index>();
|
||||||
if (batch_dims_ > 0) {
|
if (batch_dims > 0) {
|
||||||
auto params_flat = params.shaped<T, 4>(
|
auto params_flat = params.shaped<T, 4>(
|
||||||
{batch_size, outer_size, gather_dim_size, inner_size});
|
{batch_size, outer_size, gather_dim_size, inner_size});
|
||||||
auto out_flat = out->shaped<T, 4>(
|
auto out_flat = out->shaped<T, 4>(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user