diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc
index 7d538d2924f..7cb9a3a6573 100644
--- a/tensorflow/core/kernels/function_ops.cc
+++ b/tensorflow/core/kernels/function_ops.cc
@@ -87,8 +87,28 @@ class RetvalOp : public OpKernel {
 REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp);
 REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp);
 
-REGISTER_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_SYCL), ArgOp);
-REGISTER_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_SYCL), RetvalOp);
+#if TENSORFLOW_USE_SYCL
+#define REGISTER(type)     \
+  REGISTER_KERNEL_BUILDER( \
+      Name("_Arg").Device(DEVICE_SYCL).TypeConstraint<type>("T"), ArgOp);
+  TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
+  TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg")
+						 .Device(DEVICE_GPU)
+						 .HostMemory("output")
+						 .TypeConstraint<int32>("T"),
+						 ArgOp);
+#undef REGISTER
+#define REGISTER(type)     \
+  REGISTER_KERNEL_BUILDER( \
+      Name("_Retval").Device(DEVICE_SYCL).TypeConstraint<type>("T"), RetvalOp);
+  TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
+  TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval")
+						 .Device(DEVICE_GPU)
+						 .HostMemory("input")
+						 .TypeConstraint<int32>("T"),
+						 RetvalOp);
+#undef REGISTER
+#endif
 
 #define REGISTER(type)     \
   REGISTER_KERNEL_BUILDER( \