From 212d978a2d20c943724d0eb672aea9f7c4c222ba Mon Sep 17 00:00:00 2001
From: Alexandre Passos <apassos@google.com>
Date: Tue, 21 Aug 2018 10:13:15 -0700
Subject: [PATCH] Do not increment the refcount of the tensorbuffer for
 ResourceStridedSliceAssign.

This will trigger a copy if there are concurrent writes to the variable.

PiperOrigin-RevId: 209615011
---
 tensorflow/core/kernels/strided_slice_op.cc | 38 +++++++++++----------
 1 file changed, 20 insertions(+), 18 deletions(-)

diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 59fdc2262ab..7b537fef5be 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -300,7 +300,8 @@ class StridedSliceAssignOp : public OpKernel {
     gtl::InlinedVector<int64, 4> end;
     gtl::InlinedVector<int64, 4> strides;
 
-    Tensor old_lhs;
+    Tensor* old_lhs = nullptr;
+    Tensor tmp;
     if (context->input_dtype(0) == DT_RESOURCE) {
       Var* v;
       OP_REQUIRES_OK(context,
@@ -308,29 +309,30 @@ class StridedSliceAssignOp : public OpKernel {
       mutex_lock ml(*v->mu());
       OP_REQUIRES_OK(context,
                      PrepareToUpdateVariable<Device, T>(context, v->tensor()));
-      old_lhs = *v->tensor();
-      OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
+      old_lhs = v->tensor();
+      OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
                   errors::InvalidArgument(
-                      "l-value dtype ", DataTypeString(old_lhs.dtype()),
+                      "l-value dtype ", DataTypeString(old_lhs->dtype()),
                       " does not match r-value dtype ",
                       DataTypeString(DataTypeToEnum<T>::value)));
     } else {
       context->forward_ref_input_to_ref_output(0, 0);
-      old_lhs = context->mutable_input(0, true);
+      tmp = context->mutable_input(0, true);
+      old_lhs = &tmp;
     }
 
     OP_REQUIRES_OK(
-        context,
-        ValidateStridedSliceOp(
-            &context->input(1), &context->input(2), context->input(3),
-            old_lhs.shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask,
-            shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
-            &is_simple_slice, &slice_dim0, &begin, &end, &strides));
+        context,
+        ValidateStridedSliceOp(
+            &context->input(1), &context->input(2), context->input(3),
+            old_lhs->shape(), begin_mask, end_mask, ellipsis_mask,
+            new_axis_mask, shrink_axis_mask, &processing_shape,
+            &final_shape, &is_identity, &is_simple_slice, &slice_dim0,
+            &begin, &end, &strides));
 
     if (processing_shape.num_elements()) {
       const Tensor& input = context->input(4);
       TensorShape input_shape = input.shape();
-      TensorShape original_shape = old_lhs.shape();
+      TensorShape original_shape = old_lhs->shape();
       // TODO(aselle): This check is too strong, we only should need
       // input_shape to be broadcastable to final_shape
       OP_REQUIRES(
@@ -345,12 +347,12 @@ class StridedSliceAssignOp : public OpKernel {
     // scalar shape
 
     // Handle general dimensions
-#define HANDLE_DIM(NDIM)                                                 \
-  if (processing_dims == NDIM) {                                         \
-    HandleStridedSliceAssignCase<Device, T, NDIM>()(                     \
-        context, begin, end, strides, processing_shape, is_simple_slice, \
-        &old_lhs);                                                       \
-    return;                                                              \
+#define HANDLE_DIM(NDIM)                                                      \
+  if (processing_dims == NDIM) {                                              \
+    HandleStridedSliceAssignCase<Device, T, NDIM>()(context, begin, end,      \
+                                                    strides, processing_shape,\
+                                                    is_simple_slice, old_lhs);\
+    return;                                                                   \
   }
       HANDLE_DIM(0);
       HANDLE_DIM(1);