Merge pull request #27803 from eaplatanios:memory-bug-fix
PiperOrigin-RevId: 243663253
This commit is contained in:
commit
83d7c693de
@ -308,14 +308,16 @@ class StridedSliceAssignOp : public OpKernel {
|
|||||||
0, 0, input.dtype(), shape, DEVICE_MEMORY, AllocatorAttributes());
|
0, 0, input.dtype(), shape, DEVICE_MEMORY, AllocatorAttributes());
|
||||||
|
|
||||||
if (forwarded_input == nullptr) {
|
if (forwarded_input == nullptr) {
|
||||||
|
Tensor* out;
|
||||||
// We were not able to forward the input, so we deep copy the tensor and
|
// We were not able to forward the input, so we deep copy the tensor and
|
||||||
// set the output.
|
// set the output.
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
context->allocate_output(0, input.shape(), &old_lhs));
|
context->allocate_output(0, input.shape(), &out));
|
||||||
|
|
||||||
OP_REQUIRES_OK(context,
|
OP_REQUIRES_OK(context,
|
||||||
tensorflow::functor::DoCopy(
|
tensorflow::functor::DoCopy(
|
||||||
context->eigen_device<Device>(), input, old_lhs));
|
context->eigen_device<Device>(), input, out));
|
||||||
|
old_lhs = out;
|
||||||
} else {
|
} else {
|
||||||
old_lhs = forwarded_input.get();
|
old_lhs = forwarded_input.get();
|
||||||
}
|
}
|
||||||
@ -429,7 +431,6 @@ class StridedSliceAssignOp : public OpKernel {
|
|||||||
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
|
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
|
||||||
.Device(DEVICE_CPU) \
|
.Device(DEVICE_CPU) \
|
||||||
.TypeConstraint<type>("T") \
|
.TypeConstraint<type>("T") \
|
||||||
.HostMemory("input") \
|
|
||||||
.HostMemory("begin") \
|
.HostMemory("begin") \
|
||||||
.HostMemory("end") \
|
.HostMemory("end") \
|
||||||
.HostMemory("strides"), \
|
.HostMemory("strides"), \
|
||||||
@ -475,7 +476,6 @@ TF_CALL_ALL_TYPES(REGISTER_STRIDED_SLICE);
|
|||||||
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
|
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
|
||||||
.Device(DEVICE_GPU) \
|
.Device(DEVICE_GPU) \
|
||||||
.TypeConstraint<type>("T") \
|
.TypeConstraint<type>("T") \
|
||||||
.HostMemory("input") \
|
|
||||||
.HostMemory("begin") \
|
.HostMemory("begin") \
|
||||||
.HostMemory("end") \
|
.HostMemory("end") \
|
||||||
.HostMemory("strides"), \
|
.HostMemory("strides"), \
|
||||||
@ -573,7 +573,6 @@ REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")
|
|||||||
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
|
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate") \
|
||||||
.Device(DEVICE_SYCL) \
|
.Device(DEVICE_SYCL) \
|
||||||
.TypeConstraint<type>("T") \
|
.TypeConstraint<type>("T") \
|
||||||
.HostMemory("input") \
|
|
||||||
.HostMemory("begin") \
|
.HostMemory("begin") \
|
||||||
.HostMemory("end") \
|
.HostMemory("end") \
|
||||||
.HostMemory("strides"), \
|
.HostMemory("strides"), \
|
||||||
@ -619,7 +618,6 @@ REGISTER_KERNEL_BUILDER(Name("ResourceStridedSliceAssign")
|
|||||||
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")
|
REGISTER_KERNEL_BUILDER(Name("TensorStridedSliceUpdate")
|
||||||
.Device(DEVICE_SYCL)
|
.Device(DEVICE_SYCL)
|
||||||
.TypeConstraint<int32>("T")
|
.TypeConstraint<int32>("T")
|
||||||
.HostMemory("input")
|
|
||||||
.HostMemory("begin")
|
.HostMemory("begin")
|
||||||
.HostMemory("end")
|
.HostMemory("end")
|
||||||
.HostMemory("strides"),
|
.HostMemory("strides"),
|
||||||
|
@ -1653,6 +1653,7 @@ REGISTER_OP("TensorStridedSliceUpdate")
|
|||||||
.Input("end: Index")
|
.Input("end: Index")
|
||||||
.Input("strides: Index")
|
.Input("strides: Index")
|
||||||
.Input("value: T")
|
.Input("value: T")
|
||||||
|
.Output("output: T")
|
||||||
.Attr("T: type")
|
.Attr("T: type")
|
||||||
.Attr("Index: {int32, int64}")
|
.Attr("Index: {int32, int64}")
|
||||||
.Attr("begin_mask: int = 0")
|
.Attr("begin_mask: int = 0")
|
||||||
@ -1660,7 +1661,7 @@ REGISTER_OP("TensorStridedSliceUpdate")
|
|||||||
.Attr("ellipsis_mask: int = 0")
|
.Attr("ellipsis_mask: int = 0")
|
||||||
.Attr("new_axis_mask: int = 0")
|
.Attr("new_axis_mask: int = 0")
|
||||||
.Attr("shrink_axis_mask: int = 0")
|
.Attr("shrink_axis_mask: int = 0")
|
||||||
.SetShapeFn(shape_inference::NoOutputs);
|
.SetShapeFn(shape_inference::UnchangedShape);
|
||||||
|
|
||||||
REGISTER_OP("Tile")
|
REGISTER_OP("Tile")
|
||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
|
@ -86777,6 +86777,10 @@ op {
|
|||||||
name: "value"
|
name: "value"
|
||||||
type_attr: "T"
|
type_attr: "T"
|
||||||
}
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
attr {
|
attr {
|
||||||
name: "T"
|
name: "T"
|
||||||
type: "type"
|
type: "type"
|
||||||
|
Loading…
Reference in New Issue
Block a user