Only register the _Arg and _Retval kernel for POD types on sycl
This commit is contained in:
parent
d9037a06b4
commit
fc9bde9c06
@ -87,8 +87,28 @@ class RetvalOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_SYCL), ArgOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_SYCL), RetvalOp);
|
||||
#if TENSORFLOW_USE_SYCL
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
|
||||
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
|
||||
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("output")
|
||||
.TypeConstraint<int32>("T"),
|
||||
ArgOp);
|
||||
#undef REGISTER
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp);
|
||||
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
|
||||
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("input")
|
||||
.TypeConstraint<int32>("T"),
|
||||
RetvalOp);
|
||||
#undef REGISTER
|
||||
#endif
|
||||
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
|
Loading…
Reference in New Issue
Block a user