diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 7d538d2924f..7cb9a3a6573 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -87,8 +87,28 @@ class RetvalOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp); REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp); -REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_SYCL), ArgOp); -REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_SYCL), RetvalOp); +#if TENSORFLOW_USE_SYCL +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp); + TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) + TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg") + .Device(DEVICE_GPU) + .HostMemory("output") + .TypeConstraint<int32>("T"), + ArgOp); +#undef REGISTER +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp); + TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) + TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval") + .Device(DEVICE_GPU) + .HostMemory("input") + .TypeConstraint<int32>("T"), + RetvalOp); +#undef REGISTER +#endif #define REGISTER(type) \ REGISTER_KERNEL_BUILDER( \