Merge pull request #27803 from eaplatanios:memory-bug-fix

PiperOrigin-RevId: 243663253
This commit is contained in:
TensorFlower Gardener 2019-04-15 12:27:55 -07:00
commit 83d7c693de
3 changed files with 10 additions and 7 deletions

View File

@ -308,14 +308,16 @@ class StridedSliceAssignOp : public OpKernel {
0, 0, input.dtype(), shape, DEVICE_MEMORY, AllocatorAttributes()); 0, 0, input.dtype(), shape, DEVICE_MEMORY, AllocatorAttributes());
if (forwarded_input == nullptr) { if (forwarded_input == nullptr) {
Tensor* out;
// We were not able to forward the input, so we deep copy the tensor and // We were not able to forward the input, so we deep copy the tensor and
// set the output. // set the output.
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->allocate_output(0, input.shape(), &old_lhs)); context->allocate_output(0, input.shape(), &out));
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
tensorflow::functor::DoCopy( tensorflow::functor::DoCopy(
context->eigen_device<Device>(), input, old_lhs)); context->eigen_device<Device>(), input, out));
old_lhs = out;
} else { } else {
old_lhs = forwarded_input.get(); old_lhs = forwarded_input.get();
} }
@ -429,7 +431,6 @@ class StridedSliceAssignOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \ REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
.Device(DEVICE_CPU) \ .Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \ .TypeConstraint<type>("T") \
.HostMemory("input") \
.HostMemory("begin") \ .HostMemory("begin") \
.HostMemory("end") \ .HostMemory("end") \
.HostMemory("strides"), \ .HostMemory("strides"), \
@ -475,7 +476,6 @@ TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \ REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
.Device(DEVICE_GPU) \ .Device(DEVICE_GPU) \
.TypeConstraint<type>("T") \ .TypeConstraint<type>("T") \
.HostMemory("input") \
.HostMemory("begin") \ .HostMemory("begin") \
.HostMemory("end") \ .HostMemory("end") \
.HostMemory("strides"), \ .HostMemory("strides"), \
@ -573,7 +573,6 @@ REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \ REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
.Device(DEVICE_SYCL) \ .Device(DEVICE_SYCL) \
.TypeConstraint<type>("T") \ .TypeConstraint<type>("T") \
.HostMemory("input") \
.HostMemory("begin") \ .HostMemory("begin") \
.HostMemory("end") \ .HostMemory("end") \
.HostMemory("strides"), \ .HostMemory("strides"), \
@ -619,7 +618,6 @@ REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")
.Device(DEVICE_SYCL) .Device(DEVICE_SYCL)
.TypeConstraint<int32>("T") .TypeConstraint<int32>("T")
.HostMemory("input")
.HostMemory("begin") .HostMemory("begin")
.HostMemory("end") .HostMemory("end")
.HostMemory("strides"), .HostMemory("strides"),

View File

@ -1653,6 +1653,7 @@ REGISTER_OP("TensorStridedSliceUpdate")
.Input("end: Index") .Input("end: Index")
.Input("strides: Index") .Input("strides: Index")
.Input("value: T") .Input("value: T")
.Output("output: T")
.Attr("T: type") .Attr("T: type")
.Attr("Index: {int32, int64}") .Attr("Index: {int32, int64}")
.Attr("begin_mask: int = 0") .Attr("begin_mask: int = 0")
@ -1660,7 +1661,7 @@ REGISTER_OP("TensorStridedSliceUpdate")
.Attr("ellipsis_mask: int = 0") .Attr("ellipsis_mask: int = 0")
.Attr("new_axis_mask: int = 0") .Attr("new_axis_mask: int = 0")
.Attr("shrink_axis_mask: int = 0") .Attr("shrink_axis_mask: int = 0")
.SetShapeFn(shape_inference::NoOutputs); .SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("Tile") REGISTER_OP("Tile")
.Input("input: T") .Input("input: T")

View File

@ -86777,6 +86777,10 @@ op {
name: "value" name: "value"
type_attr: "T" type_attr: "T"
} }
output_arg {
name: "output"
type_attr: "T"
}
attr { attr {
name: "T" name: "T"
type: "type" type: "type"