Only register the _Arg and _Retval kernel for POD types on sycl
This commit is contained in:
parent
d9037a06b4
commit
fc9bde9c06
@ -87,8 +87,28 @@ class RetvalOp : public OpKernel {
|
|||||||
REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
|
REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
|
||||||
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
|
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_SYCL), ArgOp);
|
#if TENSORFLOW_USE_SYCL
|
||||||
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_SYCL), RetvalOp);
|
#define REGISTER(type) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
|
||||||
|
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
|
||||||
|
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
|
||||||
|
.Device(DEVICE_GPU)
|
||||||
|
.HostMemory("output")
|
||||||
|
.TypeConstraint<int32>("T"),
|
||||||
|
ArgOp);
|
||||||
|
#undef REGISTER
|
||||||
|
#define REGISTER(type) \
|
||||||
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp);
|
||||||
|
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
|
||||||
|
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
|
||||||
|
.Device(DEVICE_GPU)
|
||||||
|
.HostMemory("input")
|
||||||
|
.TypeConstraint<int32>("T"),
|
||||||
|
RetvalOp);
|
||||||
|
#undef REGISTER
|
||||||
|
#endif
|
||||||
|
|
||||||
#define REGISTER(type) \
|
#define REGISTER(type) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
|
Loading…
Reference in New Issue
Block a user