Merge pull request #27803 from eaplatanios:memory-bug-fix
PiperOrigin-RevId: 243663253
This commit is contained in:
commit
83d7c693de
@ -308,14 +308,16 @@ class StridedSliceAssignOp : public OpKernel {
|
||||
0, 0, input.dtype(), shape, DEVICE_MEMORY, AllocatorAttributes());
|
||||
|
||||
if (forwarded_input == nullptr) {
|
||||
Tensor* out;
|
||||
// We were not able to forward the input, so we deep copy the tensor and
|
||||
// set the output.
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, input.shape(), &old_lhs));
|
||||
context->allocate_output(0, input.shape(), &out));
|
||||
|
||||
OP_REQUIRES_OK(context,
|
||||
tensorflow::functor::DoCopy(
|
||||
context->eigen_device<Device>(), input, old_lhs));
|
||||
context->eigen_device<Device>(), input, out));
|
||||
old_lhs = out;
|
||||
} else {
|
||||
old_lhs = forwarded_input.get();
|
||||
}
|
||||
@ -429,7 +431,6 @@ class StridedSliceAssignOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("input") \
|
||||
.HostMemory("begin") \
|
||||
.HostMemory("end") \
|
||||
.HostMemory("strides"), \
|
||||
@ -475,7 +476,6 @@ TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("input") \
|
||||
.HostMemory("begin") \
|
||||
.HostMemory("end") \
|
||||
.HostMemory("strides"), \
|
||||
@ -573,7 +573,6 @@ REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
|
||||
.Device(DEVICE_SYCL) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("input") \
|
||||
.HostMemory("begin") \
|
||||
.HostMemory("end") \
|
||||
.HostMemory("strides"), \
|
||||
@ -619,7 +618,6 @@ REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")
|
||||
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")
|
||||
.Device(DEVICE_SYCL)
|
||||
.TypeConstraint<int32>("T")
|
||||
.HostMemory("input")
|
||||
.HostMemory("begin")
|
||||
.HostMemory("end")
|
||||
.HostMemory("strides"),
|
||||
|
@ -1653,6 +1653,7 @@ REGISTER_OP("TensorStridedSliceUpdate")
|
||||
.Input("end: Index")
|
||||
.Input("strides: Index")
|
||||
.Input("value: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Index: {int32, int64}")
|
||||
.Attr("begin_mask: int = 0")
|
||||
@ -1660,7 +1661,7 @@ REGISTER_OP("TensorStridedSliceUpdate")
|
||||
.Attr("ellipsis_mask: int = 0")
|
||||
.Attr("new_axis_mask: int = 0")
|
||||
.Attr("shrink_axis_mask: int = 0")
|
||||
.SetShapeFn(shape_inference::NoOutputs);
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("Tile")
|
||||
.Input("input: T")
|
||||
|
@ -86777,6 +86777,10 @@ op {
|
||||
name: "value"
|
||||
type_attr: "T"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
|
Loading…
Reference in New Issue
Block a user