Fix data race in GatherOp::Compute().
Don't modify a member variable (batch_dims_) inside Compute(). PiperOrigin-RevId: 347665956 Change-Id: I77dcf4be15e3961d2808b2d0ca323ab1a8d7a3a8
This commit is contained in:
parent
2dd9bc663e
commit
e54e7a08c5
@ -88,29 +88,31 @@ class GatherOp : public OpKernel {
|
||||
axis = params.dims() + axis;
|
||||
}
|
||||
|
||||
if (batch_dims_ != 0) {
|
||||
OP_REQUIRES(
|
||||
c, batch_dims_ >= -indices.dims() && batch_dims_ <= indices.dims(),
|
||||
errors::InvalidArgument("Expected batch_dims in the range [",
|
||||
-indices.dims(), ", ", indices.dims(),
|
||||
"], but got ", batch_dims_));
|
||||
// Modify only a local copy of batch_dims_.
|
||||
int32 batch_dims = batch_dims_;
|
||||
if (batch_dims != 0) {
|
||||
OP_REQUIRES(c,
|
||||
batch_dims >= -indices.dims() && batch_dims <= indices.dims(),
|
||||
errors::InvalidArgument("Expected batch_dims in the range [",
|
||||
-indices.dims(), ", ", indices.dims(),
|
||||
"], but got ", batch_dims));
|
||||
|
||||
if (batch_dims_ < 0) {
|
||||
batch_dims_ = indices.dims() + batch_dims_;
|
||||
if (batch_dims < 0) {
|
||||
batch_dims = indices.dims() + batch_dims;
|
||||
}
|
||||
|
||||
if (!axis_is_set) axis = batch_dims_;
|
||||
if (!axis_is_set) axis = batch_dims;
|
||||
|
||||
OP_REQUIRES(c, batch_dims_ < params.dims(),
|
||||
errors::InvalidArgument("batch_dims (", batch_dims_,
|
||||
OP_REQUIRES(c, batch_dims < params.dims(),
|
||||
errors::InvalidArgument("batch_dims (", batch_dims,
|
||||
") must be less than rank(params) (",
|
||||
params.dims(), ")."));
|
||||
|
||||
OP_REQUIRES(c, axis >= batch_dims_,
|
||||
errors::InvalidArgument("batch_dims (", batch_dims_,
|
||||
OP_REQUIRES(c, axis >= batch_dims,
|
||||
errors::InvalidArgument("batch_dims (", batch_dims,
|
||||
") must be less than or equal to ",
|
||||
"axis (", axis, ")."));
|
||||
for (int i = 0; i < batch_dims_; ++i) {
|
||||
for (int i = 0; i < batch_dims; ++i) {
|
||||
OP_REQUIRES(c, params.dim_size(i) == indices.dim_size(i),
|
||||
errors::InvalidArgument(
|
||||
"params.shape[", i, "]: ", params.dim_size(i),
|
||||
@ -136,15 +138,15 @@ class GatherOp : public OpKernel {
|
||||
int64 outer_size = 1;
|
||||
int64 inner_size = 1;
|
||||
|
||||
for (int i = 0; i < batch_dims_; ++i) {
|
||||
for (int i = 0; i < batch_dims; ++i) {
|
||||
result_shape.AddDim(params.dim_size(i));
|
||||
batch_size *= params.dim_size(i);
|
||||
}
|
||||
for (int i = batch_dims_; i < axis; ++i) {
|
||||
for (int i = batch_dims; i < axis; ++i) {
|
||||
result_shape.AddDim(params.dim_size(i));
|
||||
outer_size *= params.dim_size(i);
|
||||
}
|
||||
for (int i = batch_dims_; i < indices.dims(); ++i) {
|
||||
for (int i = batch_dims; i < indices.dims(); ++i) {
|
||||
result_shape.AddDim(indices.dim_size(i));
|
||||
}
|
||||
for (int i = axis + 1; i < params.dims(); ++i) {
|
||||
@ -159,7 +161,7 @@ class GatherOp : public OpKernel {
|
||||
|
||||
int64 bad_i = -1;
|
||||
auto indices_flat = indices.flat<Index>();
|
||||
if (batch_dims_ > 0) {
|
||||
if (batch_dims > 0) {
|
||||
auto params_flat = params.shaped<T, 4>(
|
||||
{batch_size, outer_size, gather_dim_size, inner_size});
|
||||
auto out_flat = out->shaped<T, 4>(
|
||||
|
Loading…
Reference in New Issue
Block a user