Only register the _Arg and _Retval kernel for POD types on sycl

This commit is contained in:
Benoit Steiner 2016-11-02 19:40:19 -07:00
parent d9037a06b4
commit fc9bde9c06

View File

@ -87,8 +87,28 @@ class RetvalOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_SYCL), ArgOp);
REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_SYCL), RetvalOp);
#if TENSORFLOW_USE_SYCL
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
.Device(DEVICE_GPU)
.HostMemory("output")
.TypeConstraint<int32>("T"),
ArgOp);
#undef REGISTER
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
.Device(DEVICE_GPU)
.HostMemory("input")
.TypeConstraint<int32>("T"),
RetvalOp);
#undef REGISTER
#endif
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \