From fc9bde9c0675116490d204c21f81c764691503f9 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Wed, 2 Nov 2016 19:40:19 -0700 Subject: [PATCH] Only register the _Arg and _Retval kernel for POD types on sycl --- tensorflow/core/kernels/function_ops.cc | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 7d538d2924f..7cb9a3a6573 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -87,8 +87,28 @@ class RetvalOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp); REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp); -REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_SYCL), ArgOp); -REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_SYCL), RetvalOp); +#if TENSORFLOW_USE_SYCL +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("_Arg").Device(DEVICE_SYCL).TypeConstraint("T"), ArgOp); + TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) + TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg") + .Device(DEVICE_GPU) + .HostMemory("output") + .TypeConstraint("T"), + ArgOp); +#undef REGISTER +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("_Retval").Device(DEVICE_SYCL).TypeConstraint("T"), RetvalOp); + TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) + TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval") + .Device(DEVICE_GPU) + .HostMemory("input") + .TypeConstraint("T"), + RetvalOp); +#undef REGISTER +#endif #define REGISTER(type) \ REGISTER_KERNEL_BUILDER( \