Arg/Ret ops should return/take tensors in device memory
PiperOrigin-RevId: 238643502
This commit is contained in:
parent
9a43dfeac5
commit
ea6c5f5cc3
@ -199,9 +199,7 @@ class XlaAssignVariableOp : public OpKernel {
|
|||||||
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \
|
Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp); \
|
||||||
\
|
\
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
Name(kArgOp).Device(DEVICE).HostMemory("output").TypeConstraint("T", \
|
Name(kArgOp).Device(DEVICE).TypeConstraint("T", TYPES), ArgOp); \
|
||||||
TYPES), \
|
|
||||||
ArgOp); \
|
|
||||||
REGISTER_KERNEL_BUILDER(Name(kArgOp) \
|
REGISTER_KERNEL_BUILDER(Name(kArgOp) \
|
||||||
.Device(DEVICE) \
|
.Device(DEVICE) \
|
||||||
.HostMemory("output") \
|
.HostMemory("output") \
|
||||||
@ -210,11 +208,8 @@ class XlaAssignVariableOp : public OpKernel {
|
|||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
Name(kArgOp).Device(DEVICE).TypeConstraint<Variant>("T"), ArgOp); \
|
Name(kArgOp).Device(DEVICE).TypeConstraint<Variant>("T"), ArgOp); \
|
||||||
\
|
\
|
||||||
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
.Device(DEVICE) \
|
Name(kRetOp).Device(DEVICE).TypeConstraint("T", TYPES), RetvalOp); \
|
||||||
.TypeConstraint("T", TYPES) \
|
|
||||||
.HostMemory("input"), \
|
|
||||||
RetvalOp); \
|
|
||||||
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
|
REGISTER_KERNEL_BUILDER(Name(kRetOp) \
|
||||||
.Device(DEVICE) \
|
.Device(DEVICE) \
|
||||||
.TypeConstraint<ResourceHandle>("T") \
|
.TypeConstraint<ResourceHandle>("T") \
|
||||||
|
Loading…
Reference in New Issue
Block a user