Do not increment the refcount of the TensorBuffer for ResourceStridedSliceAssign.
Doing so would trigger a copy if there are concurrent writes to the variable.

PiperOrigin-RevId: 209615011
This commit is contained in:
parent 99107ec949
commit 212d978a2d
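In the diff below, the old code copied the variable's Tensor by value (old_lhs = *v->tensor()), which increments the reference count of the underlying TensorBuffer; because the copy-on-write path for resource variables only updates in place when that buffer is uniquely referenced, the extra reference could force a copy under concurrent writes. The new code instead holds a Tensor* into the variable (while the variable's mutex is held), leaving the refcount untouched. The following is a minimal, self-contained sketch of that interaction; it is not TensorFlow code, and ToyTensor / CopyOnWriteAssign are hypothetical stand-ins that model the refcounted buffer with std::shared_ptr.

// Toy model (not TensorFlow code): a "tensor" whose storage is a refcounted
// buffer, plus a copy-on-write update that only writes in place when the
// buffer is uniquely owned, mirroring the behavior the commit message
// describes for resource variables.
#include <cassert>
#include <memory>
#include <vector>

struct ToyTensor {
  std::shared_ptr<std::vector<float>> buf;  // shared, refcounted storage
};

// Write in place only if no one else holds the buffer; otherwise copy first.
void CopyOnWriteAssign(ToyTensor* var, float value) {
  if (var->buf.use_count() > 1) {
    var->buf = std::make_shared<std::vector<float>>(*var->buf);  // forced copy
  }
  (*var->buf)[0] = value;
}

int main() {
  ToyTensor variable{std::make_shared<std::vector<float>>(4, 0.f)};

  // Old pattern: taking a by-value copy of the variable's tensor bumps the
  // buffer refcount, so a concurrent assign has to copy the buffer.
  {
    ToyTensor old_lhs = variable;       // refcount is now 2
    CopyOnWriteAssign(&variable, 1.f);  // triggers a buffer copy
    assert(old_lhs.buf != variable.buf);
  }

  // New pattern: holding a pointer to the variable's tensor leaves the
  // refcount at 1, so the assign can update the buffer in place.
  ToyTensor* old_lhs = &variable;       // refcount still 1
  CopyOnWriteAssign(&variable, 2.f);    // writes in place, no copy
  assert((*old_lhs->buf)[0] == 2.f);
  return 0;
}

Note that in the actual diff the non-resource (ref-variable) branch still assigns context->mutable_input(0, true) into a local Tensor tmp and points old_lhs at it, so only the resource-variable path gains the pointer-based, refcount-neutral behavior.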
@@ -300,7 +300,8 @@ class StridedSliceAssignOp : public OpKernel {
     gtl::InlinedVector<int64, 4> end;
     gtl::InlinedVector<int64, 4> strides;
 
-    Tensor old_lhs;
+    Tensor* old_lhs = nullptr;
+    Tensor tmp;
     if (context->input_dtype(0) == DT_RESOURCE) {
       Var* v;
       OP_REQUIRES_OK(context,
@@ -308,29 +309,30 @@ class StridedSliceAssignOp : public OpKernel {
       mutex_lock ml(*v->mu());
       OP_REQUIRES_OK(context,
                      PrepareToUpdateVariable<Device, T>(context, v->tensor()));
-      old_lhs = *v->tensor();
-      OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
+      old_lhs = v->tensor();
+      OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
                   errors::InvalidArgument(
-                      "l-value dtype ", DataTypeString(old_lhs.dtype()),
+                      "l-value dtype ", DataTypeString(old_lhs->dtype()),
                       " does not match r-value dtype ",
                       DataTypeString(DataTypeToEnum<T>::value)));
     } else {
       context->forward_ref_input_to_ref_output(0, 0);
-      old_lhs = context->mutable_input(0, true);
+      tmp = context->mutable_input(0, true);
+      old_lhs = &tmp;
     }
 
     OP_REQUIRES_OK(
-        context,
-        ValidateStridedSliceOp(
-            &context->input(1), &context->input(2), context->input(3),
-            old_lhs.shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask,
-            shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
-            &is_simple_slice, &slice_dim0, &begin, &end, &strides));
+        context, ValidateStridedSliceOp(
+                     &context->input(1), &context->input(2), context->input(3),
+                     old_lhs->shape(), begin_mask, end_mask, ellipsis_mask,
+                     new_axis_mask, shrink_axis_mask, &processing_shape,
+                     &final_shape, &is_identity, &is_simple_slice, &slice_dim0,
+                     &begin, &end, &strides));
 
     if (processing_shape.num_elements()) {
       const Tensor& input = context->input(4);
       TensorShape input_shape = input.shape();
-      TensorShape original_shape = old_lhs.shape();
+      TensorShape original_shape = old_lhs->shape();
       // TODO(aselle): This check is too strong, we only should need
       // input_shape to be broadcastable to final_shape
       OP_REQUIRES(
@@ -345,12 +347,12 @@ class StridedSliceAssignOp : public OpKernel {
       // scalar shape
 
       // Handle general dimensions
-#define HANDLE_DIM(NDIM)                                                  \
-  if (processing_dims == NDIM) {                                          \
-    HandleStridedSliceAssignCase<Device, T, NDIM>()(                      \
-        context, begin, end, strides, processing_shape, is_simple_slice,  \
-        &old_lhs);                                                        \
-    return;                                                               \
+#define HANDLE_DIM(NDIM)                                                  \
+  if (processing_dims == NDIM) {                                          \
+    HandleStridedSliceAssignCase<Device, T, NDIM>()(context, begin, end,  \
+                                                    strides, processing_shape, \
+                                                    is_simple_slice, old_lhs); \
+    return;                                                               \
   }
       HANDLE_DIM(0);
       HANDLE_DIM(1);