Do not increment the refcount of the tensorbuffer for ResourceStridedSliceAssign.

This will trigger a copy if there are concurrent writes to the variable.

PiperOrigin-RevId: 209615011
This commit is contained in:
Alexandre Passos 2018-08-21 10:13:15 -07:00 committed by TensorFlower Gardener
parent 99107ec949
commit 212d978a2d

View File

@ -300,7 +300,8 @@ class StridedSliceAssignOp : public OpKernel {
gtl::InlinedVector<int64, 4> end;
gtl::InlinedVector<int64, 4> strides;
Tensor old_lhs;
Tensor* old_lhs = nullptr;
Tensor tmp;
if (context->input_dtype(0) == DT_RESOURCE) {
Var* v;
OP_REQUIRES_OK(context,
@ -308,29 +309,30 @@ class StridedSliceAssignOp : public OpKernel {
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, v->tensor()));
old_lhs = *v->tensor();
OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
old_lhs = v->tensor();
OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
errors::InvalidArgument(
"l-value dtype ", DataTypeString(old_lhs.dtype()),
"l-value dtype ", DataTypeString(old_lhs->dtype()),
" does not match r-value dtype ",
DataTypeString(DataTypeToEnum<T>::value)));
} else {
context->forward_ref_input_to_ref_output(0, 0);
old_lhs = context->mutable_input(0, true);
tmp = context->mutable_input(0, true);
old_lhs = &tmp;
}
OP_REQUIRES_OK(
context,
ValidateStridedSliceOp(
&context->input(1), &context->input(2), context->input(3),
old_lhs.shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask,
shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
&is_simple_slice, &slice_dim0, &begin, &end, &strides));
context, ValidateStridedSliceOp(
&context->input(1), &context->input(2), context->input(3),
old_lhs->shape(), begin_mask, end_mask, ellipsis_mask,
new_axis_mask, shrink_axis_mask, &processing_shape,
&final_shape, &is_identity, &is_simple_slice, &slice_dim0,
&begin, &end, &strides));
if (processing_shape.num_elements()) {
const Tensor& input = context->input(4);
TensorShape input_shape = input.shape();
TensorShape original_shape = old_lhs.shape();
TensorShape original_shape = old_lhs->shape();
// TODO(aselle): This check is too strong, we only should need
// input_shape to be broadcastable to final_shape
OP_REQUIRES(
@ -345,12 +347,12 @@ class StridedSliceAssignOp : public OpKernel {
// scalar shape
// Handle general dimensions
#define HANDLE_DIM(NDIM) \
if (processing_dims == NDIM) { \
HandleStridedSliceAssignCase<Device, T, NDIM>()( \
context, begin, end, strides, processing_shape, is_simple_slice, \
&old_lhs); \
return; \
#define HANDLE_DIM(NDIM) \
if (processing_dims == NDIM) { \
HandleStridedSliceAssignCase<Device, T, NDIM>()(context, begin, end, \
strides, processing_shape, \
is_simple_slice, old_lhs); \
return; \
}
HANDLE_DIM(0);
HANDLE_DIM(1);