diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index a98e0b7fd87..e558e6e80a2 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -116,6 +116,8 @@ load( "//third_party/mkl:build_defs.bzl", "if_mkl", ) +load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl") + # ----------------------------------------------------------------------------- # Public targets @@ -729,7 +731,7 @@ cc_library( "//tensorflow/core/kernels:ops_testutil", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/platform/default/build_config:gtest", - ], + ] + if_sycl([":sycl_runtime"]), ) # This is a link-only library to provide a DirectSession diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.cc b/tensorflow/core/common_runtime/sycl/sycl_allocator.cc index 0d238276f4f..b7ef9361e95 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_allocator.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.cc @@ -25,6 +25,9 @@ string SYCLAllocator::Name() { return "device:SYCL"; } void *SYCLAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { assert(device_); + if (num_bytes == 0) { + return device_->allocate(1); + } auto p = device_->allocate(num_bytes); return p; } @@ -42,6 +45,6 @@ void SYCLAllocator::EnterLameDuckMode() { } } -} // namespace tensorflow +} // namespace tensorflow -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.h b/tensorflow/core/common_runtime/sycl/sycl_allocator.h index c896f7f6037..15d9ab41a46 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_allocator.h +++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.h @@ -27,8 +27,8 @@ limitations under the License. namespace tensorflow { class SYCLAllocator : public Allocator { -public: - SYCLAllocator(Eigen::QueueInterface* device) : device_(device) {} + public: + SYCLAllocator(Eigen::QueueInterface *device) : device_(device) {} virtual ~SYCLAllocator() override; string Name() override; void *AllocateRaw(size_t alignment, size_t num_bytes) override; @@ -36,11 +36,12 @@ public: void EnterLameDuckMode(); virtual bool ShouldAllocateEmptyTensors() override final { return true; } -private: + + private: Eigen::QueueInterface *device_; // not owned TF_DISALLOW_COPY_AND_ASSIGN(SYCLAllocator); }; -} // namespace tensorflow +} // namespace tensorflow -#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_ +#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.cc b/tensorflow/core/common_runtime/sycl/sycl_device.cc index 0abe25c373e..2c2185b2c03 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device.cc @@ -23,7 +23,7 @@ limitations under the License. namespace tensorflow { -static std::unordered_set live_devices; +static std::unordered_set live_devices; static bool first_time = true; void ShutdownSycl() { diff --git a/tensorflow/core/common_runtime/sycl/sycl_device.h b/tensorflow/core/common_runtime/sycl/sycl_device.h index b5a72d94763..a5c7c5f0ec7 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device.h +++ b/tensorflow/core/common_runtime/sycl/sycl_device.h @@ -34,10 +34,11 @@ class SYCLDevice : public LocalDevice { Bytes memory_limit, const DeviceLocality &locality, const string &physical_device_desc, SYCLSelector sycl_selector, Allocator *cpu_allocator) - : LocalDevice(options, Device::BuildDeviceAttributes( - name, DEVICE_SYCL, memory_limit, locality, - physical_device_desc), - nullptr), + : LocalDevice( + options, + Device::BuildDeviceAttributes(name, DEVICE_SYCL, memory_limit, + locality, physical_device_desc), + nullptr), cpu_allocator_(cpu_allocator), sycl_queue_(new Eigen::QueueInterface(sycl_selector)), sycl_device_(new Eigen::SyclDevice(sycl_queue_)), diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc index a6be9195d4b..1c868f5606e 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device_context.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device_context.cc @@ -17,8 +17,8 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h" #include "tensorflow/core/common_runtime/dma_helper.h" +#include "tensorflow/core/common_runtime/sycl/sycl_device_context.h" namespace tensorflow { @@ -31,68 +31,68 @@ void SYCLDeviceContext::CopyCPUTensorToDevice(const Tensor *cpu_tensor, const void *src_ptr = DMAHelper::base(cpu_tensor); void *dst_ptr = DMAHelper::base(device_tensor); switch (cpu_tensor->dtype()) { - case DT_FLOAT: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_DOUBLE: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_INT32: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_INT64: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_HALF: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast(dst_ptr), - static_cast(src_ptr), total_bytes); - break; - case DT_COMPLEX64: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast *>(dst_ptr), - static_cast *>(src_ptr), total_bytes); - break; - case DT_COMPLEX128: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast *>(dst_ptr), - static_cast *>(src_ptr), total_bytes); - break; - case DT_INT8: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_INT16: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_UINT8: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_UINT16: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_BOOL: - device->eigen_sycl_device()->memcpyHostToDevice( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - default: - assert(false && "unsupported type"); + case DT_FLOAT: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_DOUBLE: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast(dst_ptr), + static_cast(src_ptr), total_bytes); + break; + case DT_INT32: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_INT64: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_HALF: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast(dst_ptr), + static_cast(src_ptr), total_bytes); + break; + case DT_COMPLEX64: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast *>(dst_ptr), + static_cast *>(src_ptr), total_bytes); + break; + case DT_COMPLEX128: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast *>(dst_ptr), + static_cast *>(src_ptr), total_bytes); + break; + case DT_INT8: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_INT16: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_UINT8: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_UINT16: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast(dst_ptr), + static_cast(src_ptr), total_bytes); + break; + case DT_BOOL: + device->eigen_sycl_device()->memcpyHostToDevice( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + default: + assert(false && "unsupported type"); } } device->eigen_sycl_device()->synchronize(); @@ -106,71 +106,71 @@ void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor *device_tensor, StatusCallback done) { const int64 total_bytes = device_tensor->TotalBytes(); if (total_bytes > 0) { - const void* src_ptr = DMAHelper::base(device_tensor); - void* dst_ptr = DMAHelper::base(cpu_tensor); + const void *src_ptr = DMAHelper::base(device_tensor); + void *dst_ptr = DMAHelper::base(cpu_tensor); switch (device_tensor->dtype()) { - case DT_FLOAT: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_DOUBLE: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_INT32: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_INT64: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_HALF: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast(dst_ptr), - static_cast(src_ptr), total_bytes); - break; - case DT_COMPLEX64: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast *>(dst_ptr), - static_cast *>(src_ptr), total_bytes); - break; - case DT_COMPLEX128: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast *>(dst_ptr), - static_cast *>(src_ptr), total_bytes); - break; - case DT_INT8: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_INT16: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_UINT8: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_UINT16: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - case DT_BOOL: - device->eigen_sycl_device()->memcpyDeviceToHost( - static_cast(dst_ptr), static_cast(src_ptr), - total_bytes); - break; - default: - assert(false && "unsupported type"); + case DT_FLOAT: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_DOUBLE: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast(dst_ptr), + static_cast(src_ptr), total_bytes); + break; + case DT_INT32: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_INT64: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_HALF: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast(dst_ptr), + static_cast(src_ptr), total_bytes); + break; + case DT_COMPLEX64: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast *>(dst_ptr), + static_cast *>(src_ptr), total_bytes); + break; + case DT_COMPLEX128: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast *>(dst_ptr), + static_cast *>(src_ptr), total_bytes); + break; + case DT_INT8: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_INT16: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_UINT8: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + case DT_UINT16: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast(dst_ptr), + static_cast(src_ptr), total_bytes); + break; + case DT_BOOL: + device->eigen_sycl_device()->memcpyDeviceToHost( + static_cast(dst_ptr), static_cast(src_ptr), + total_bytes); + break; + default: + assert(false && "unsupported type"); } } device->eigen_sycl_device()->synchronize(); @@ -178,4 +178,4 @@ void SYCLDeviceContext::CopyDeviceTensorToCPU(const Tensor *device_tensor, } } // namespace tensorflow -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_context.h b/tensorflow/core/common_runtime/sycl/sycl_device_context.h index 1f7ad543d94..0f8f17b8058 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device_context.h +++ b/tensorflow/core/common_runtime/sycl/sycl_device_context.h @@ -26,7 +26,7 @@ limitations under the License. namespace tensorflow { class SYCLDeviceContext : public DeviceContext { -public: + public: SYCLDeviceContext() {} ~SYCLDeviceContext() override {} @@ -40,6 +40,6 @@ public: StatusCallback done) override; }; -} // namespace tensorflow +} // namespace tensorflow -#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_ +#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_DEVICE_CONTEXT_H_ diff --git a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc index 51eb4973d8a..a643fc72580 100644 --- a/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc +++ b/tensorflow/core/common_runtime/sycl/sycl_device_factory.cc @@ -21,7 +21,7 @@ limitations under the License. namespace tensorflow { class SYCLDeviceFactory : public DeviceFactory { -public: + public: Status CreateDevices(const SessionOptions &options, const string &name_prefix, std::vector *devices) override { int n = 1; @@ -31,10 +31,10 @@ public: } for (int i = 0; i < n; i++) { string name = strings::StrCat(name_prefix, "/device:SYCL:", i); - devices->push_back(new SYCLDevice(options, name, Bytes(256 << 20), - DeviceLocality(), - SYCLDevice::GetShortDeviceDescription(), - cl::sycl::gpu_selector(), cpu_allocator())); + devices->push_back( + new SYCLDevice(options, name, Bytes(256 << 20), DeviceLocality(), + SYCLDevice::GetShortDeviceDescription(), + cl::sycl::gpu_selector(), cpu_allocator())); } return Status::OK(); } @@ -43,4 +43,4 @@ public: REGISTER_LOCAL_DEVICE_FACTORY("SYCL", SYCLDeviceFactory, 200); } -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/framework/register_types_traits.h b/tensorflow/core/framework/register_types_traits.h index 8f8d9fd08e6..c1fe5517c69 100644 --- a/tensorflow/core/framework/register_types_traits.h +++ b/tensorflow/core/framework/register_types_traits.h @@ -21,6 +21,10 @@ limitations under the License. typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL + #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/platform/types.h" @@ -66,6 +70,17 @@ struct proxy_type_pod { typedef Eigen::half type; }; +#ifdef TENSORFLOW_USE_SYCL +template <> +struct proxy_type_pod { + typedef double type; +}; +template <> +struct proxy_type_pod { + typedef float type; +}; +#endif // TENSORFLOW_USE_SYCL + /// If POD we use proxy_type_pod, otherwise this maps to identiy. template struct proxy_type { @@ -81,6 +96,10 @@ struct proxy_type { TF_CALL_int8(m) TF_CALL_complex128(m) #define TF_CALL_GPU_PROXY_TYPES(m) \ TF_CALL_double(m) TF_CALL_float(m) TF_CALL_half(m) TF_CALL_int32(m) +#ifdef TENSORFLOW_USE_SYCL +#define TF_CALL_SYCL_PROXY_TYPES(m) \ + TF_CALL_double(m) TF_CALL_float(m) TF_CALL_int32(m) +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_ diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index ab82c247d65..562934ed63b 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -34,6 +34,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL #define CURRY_TYPES2(FN, arg0) \ FN(arg0, bool); \ @@ -206,6 +209,52 @@ REGISTER_CAST_GPU(bfloat16, float); #undef REGISTER_CAST_GPU #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +class SyclCastOp : public CastOpBase { + public: + explicit SyclCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) { + OP_REQUIRES_OK(ctx, Prepare()); + } + + private: + Status Prepare() { + if (src_dtype_ == dst_dtype_) { + work_ = nullptr; // Identity + return Status::OK(); + } + if (src_dtype_ == DT_BOOL) { + work_ = GetSyclCastFromBool(dst_dtype_); + } else if (src_dtype_ == DT_INT32) { + work_ = GetSyclCastFromInt32(dst_dtype_); + } else if (src_dtype_ == DT_INT64) { + work_ = GetSyclCastFromInt64(dst_dtype_); + } else if (src_dtype_ == DT_FLOAT) { + work_ = GetSyclCastFromFloat(dst_dtype_); + } else if (src_dtype_ == DT_DOUBLE) { + work_ = GetSyclCastFromDouble(dst_dtype_); + } + + return work_ == nullptr ? Unimplemented() : Status::OK(); + } +}; + +#define REGISTER_CAST_SYCL(srctype, dsttype) \ + REGISTER_KERNEL_BUILDER(Name("Cast") \ + .TypeConstraint("SrcT") \ + .TypeConstraint("DstT") \ + .Device(DEVICE_SYCL), \ + SyclCastOp) + +CURRY_TYPES2(REGISTER_CAST_SYCL, bool); +CURRY_TYPES2(REGISTER_CAST_SYCL, int32); +CURRY_TYPES2(REGISTER_CAST_SYCL, int64); +CURRY_TYPES2(REGISTER_CAST_SYCL, float); +CURRY_TYPES2(REGISTER_CAST_SYCL, double); + +#undef REGISTER_CAST_SYCL + +#endif // TENSORFLOW_USE_SYCL + #undef CURRY_TYPES2 // HostCast differs from Cast in that its input and output are in host memory. @@ -213,5 +262,10 @@ REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp); REGISTER_KERNEL_BUILDER( Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"), CpuCastOp); - +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("_HostCast").Device(DEVICE_SYCL).HostMemory("x").HostMemory("y"), + CpuCastOp); +#endif // TENSORFLOW_USE_SYCL } // end namespace tensorflow + diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h index cb7cc81937a..1ee0796ac14 100644 --- a/tensorflow/core/kernels/cast_op_impl.h +++ b/tensorflow/core/kernels/cast_op_impl.h @@ -33,6 +33,16 @@ struct CastFunctor { } }; +#ifdef TENSORFLOW_USE_SYCL +template +struct CastFunctor { + void operator()(const Eigen::SyclDevice& d, typename TTypes::Flat o, + typename TTypes::ConstFlat i) { + o.device(d) = i.template cast(); + } +}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor #define CURRY_TYPES3(FN, arg0, arg1) \ @@ -140,6 +150,25 @@ GetGpuCastFromBfloat(DataType dst_dtype); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +std::function +GetSyclCastFromBool(DataType dst_dtype); + +std::function +GetSyclCastFromInt32(DataType dst_dtype); + +std::function +GetSyclCastFromInt64(DataType dst_dtype); + +std::function +GetSyclCastFromFloat(DataType dst_dtype); + +std::function +GetSyclCastFromDouble(DataType dst_dtype); + +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ + diff --git a/tensorflow/core/kernels/cast_op_impl_bool.cc b/tensorflow/core/kernels/cast_op_impl_bool.cc index 92fee89a475..a13f1630092 100644 --- a/tensorflow/core/kernels/cast_op_impl_bool.cc +++ b/tensorflow/core/kernels/cast_op_impl_bool.cc @@ -34,4 +34,14 @@ GetGpuCastFromBool(DataType dst_dtype) { } #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +std::function +GetSyclCastFromBool(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, SYCLDevice, bool); + return nullptr; +} +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow + diff --git a/tensorflow/core/kernels/cast_op_impl_double.cc b/tensorflow/core/kernels/cast_op_impl_double.cc index fd20061d216..fdc8d51158f 100644 --- a/tensorflow/core/kernels/cast_op_impl_double.cc +++ b/tensorflow/core/kernels/cast_op_impl_double.cc @@ -34,4 +34,14 @@ GetGpuCastFromDouble(DataType dst_dtype) { } #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +std::function +GetSyclCastFromDouble(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, SYCLDevice, double); + return nullptr; +} +#endif // TENSORFLOW_USE_SYC + } // namespace tensorflow + diff --git a/tensorflow/core/kernels/cast_op_impl_float.cc b/tensorflow/core/kernels/cast_op_impl_float.cc index 71e63fbff0f..1241dcd8f2e 100644 --- a/tensorflow/core/kernels/cast_op_impl_float.cc +++ b/tensorflow/core/kernels/cast_op_impl_float.cc @@ -49,4 +49,14 @@ GetGpuCastFromFloat(DataType dst_dtype) { } #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +std::function +GetSyclCastFromFloat(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, SYCLDevice, float); + return nullptr; +} +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow + diff --git a/tensorflow/core/kernels/cast_op_impl_int32.cc b/tensorflow/core/kernels/cast_op_impl_int32.cc index 0fc6e16afea..fca9cd60ec1 100644 --- a/tensorflow/core/kernels/cast_op_impl_int32.cc +++ b/tensorflow/core/kernels/cast_op_impl_int32.cc @@ -34,4 +34,14 @@ GetGpuCastFromInt32(DataType dst_dtype) { } #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +std::function +GetSyclCastFromInt32(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, CPUDevice, int32); + return nullptr; +} +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow + diff --git a/tensorflow/core/kernels/cast_op_impl_int64.cc b/tensorflow/core/kernels/cast_op_impl_int64.cc index b5571b19a5d..c0a543708d3 100644 --- a/tensorflow/core/kernels/cast_op_impl_int64.cc +++ b/tensorflow/core/kernels/cast_op_impl_int64.cc @@ -19,6 +19,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL std::function GetCpuCastFromInt64(DataType dst_dtype) { @@ -34,4 +37,13 @@ GetGpuCastFromInt64(DataType dst_dtype) { } #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +std::function +GetSyclCastFromInt64(DataType dst_dtype) { + CURRY_TYPES3(CAST_CASE, SYCLDevice, int64); + return nullptr; +} +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/cast_op_test.cc b/tensorflow/core/kernels/cast_op_test.cc index 5b7529bb8a9..a106f287c18 100644 --- a/tensorflow/core/kernels/cast_op_test.cc +++ b/tensorflow/core/kernels/cast_op_test.cc @@ -105,7 +105,12 @@ static void BM_gpu_float_int64(int iters, int num) { testing::BytesProcessed(static_cast(iters) * num * (sizeof(float) + sizeof(int64))); testing::UseRealTime(); +#if GOOGLE_CUDA test::Benchmark("gpu", Cast(num)).Run(iters); +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL + test::Benchmark("sycl", Cast(num)).Run(iters); +#endif // TENSORFLOW_USE_SYCL } BENCHMARK(BM_gpu_float_int64)->Arg(64 << 10)->Arg(32 << 20); @@ -123,7 +128,12 @@ static void BM_gpu_bool_float(int iters, int num) { testing::BytesProcessed(static_cast(iters) * num * (sizeof(bool) + sizeof(float))); testing::UseRealTime(); +#if GOOGLE_CUDA test::Benchmark("gpu", Cast(num)).Run(iters); +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL + test::Benchmark("sycl", Cast(num)).Run(iters); +#endif // TENSORFLOW_USE_SYCL } BENCHMARK(BM_gpu_bool_float)->Arg(64 << 10)->Arg(32 << 20); @@ -168,7 +178,9 @@ static void BM_gpu_float_half(int iters, int num) { testing::BytesProcessed(static_cast(iters) * num * (sizeof(float) + sizeof(Eigen::half))); testing::UseRealTime(); +#if GOOGLE_CUDA test::Benchmark("gpu", Cast(num)).Run(iters); +#endif // GOOGLE_CUDA } BENCHMARK(BM_gpu_float_half)->Arg(64 << 10)->Arg(32 << 20); @@ -177,7 +189,9 @@ static void BM_gpu_half_float(int iters, int num) { testing::BytesProcessed(static_cast(iters) * num * (sizeof(float) + sizeof(Eigen::half))); testing::UseRealTime(); +#if GOOGLE_CUDA test::Benchmark("gpu", Cast(num)).Run(iters); +#endif // GOOGLE_CUDA } BENCHMARK(BM_gpu_half_float)->Arg(64 << 10)->Arg(32 << 20); diff --git a/tensorflow/core/kernels/concat_lib.h b/tensorflow/core/kernels/concat_lib.h index cef873f804a..14e6e1bc324 100644 --- a/tensorflow/core/kernels/concat_lib.h +++ b/tensorflow/core/kernels/concat_lib.h @@ -38,6 +38,14 @@ void ConcatGPU( Tensor* output, typename TTypes::Tensor* output_flat); #endif // GOOGLE_CUDA + +#ifdef TENSORFLOW_USE_SYCL +template +void ConcatSYCL(const Eigen::SyclDevice& d, + const std::vector< + std::unique_ptr::ConstMatrix>>& inputs, + typename TTypes::Matrix* output); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow #endif // TENSORFLOW_KERNELS_CONCAT_LIB_H_ diff --git a/tensorflow/core/kernels/concat_lib_cpu.cc b/tensorflow/core/kernels/concat_lib_cpu.cc index f83aed6aefd..f89948350c3 100644 --- a/tensorflow/core/kernels/concat_lib_cpu.cc +++ b/tensorflow/core/kernels/concat_lib_cpu.cc @@ -74,4 +74,23 @@ REGISTER(qint16) REGISTER(qint32) REGISTER(bfloat16) +#ifdef TENSORFLOW_USE_SYCL +template +void ConcatSYCL(const Eigen::SyclDevice& d, + const std::vector< + std::unique_ptr::ConstMatrix>>& inputs, + typename TTypes::Matrix* output) { + ConcatSYCLImpl(d, inputs, sizeof(T) /* cost_per_unit */, MemCpyCopier(), + output); +} +#define REGISTER_SYCL(T) \ + template void ConcatSYCL( \ + const Eigen::SyclDevice&, \ + const std::vector::ConstMatrix>>&, \ + typename TTypes::Matrix* output); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL) + +#undef REGISTER_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/concat_lib_cpu.h b/tensorflow/core/kernels/concat_lib_cpu.h index 9d37cafb4ed..6a933efde4b 100644 --- a/tensorflow/core/kernels/concat_lib_cpu.h +++ b/tensorflow/core/kernels/concat_lib_cpu.h @@ -126,4 +126,39 @@ void ConcatCPUImpl( cost_per_unit, work); } +#ifdef TENSORFLOW_USE_SYCL +template +void ConcatSYCLImpl( + const Eigen::SyclDevice& d, + const std::vector::ConstMatrix>>& + inputs, + int64 cost_per_unit, ElementCopier copier, + typename TTypes::Matrix* output) { + size_t num_inputs = inputs.size(); + + std::vector sizes; + sizes.reserve(num_inputs); + int64 row_size = 0; + for (const auto& input : inputs) { + sizes.push_back(input->dimension(1)); + row_size += sizes.back(); + } + + T* out = &(*output)(0, 0); + std::vector inp; + inp.reserve(num_inputs); + for (const auto& input : inputs) { + inp.push_back(&(*input)(0, 0)); + } + const int64 dim0 = output->dimension(0); + for (int64 i = 0; i < dim0; ++i) { + for (int64 j = 0; j < num_inputs; ++j) { + auto size = sizes[j]; + d.memcpy(out, inp[j], size * sizeof(T)); + out += size; + inp[j] += size; + } + } +} +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc index e6dae5fa7eb..9628a7efa4b 100644 --- a/tensorflow/core/kernels/concat_op.cc +++ b/tensorflow/core/kernels/concat_op.cc @@ -35,6 +35,9 @@ typedef Eigen::ThreadPoolDevice CPUDevice; #if GOOGLE_CUDA typedef Eigen::GpuDevice GPUDevice; #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM }; @@ -134,6 +137,12 @@ class ConcatBaseOp : public OpKernel { return; } #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL + if (std::is_same::value) { + ConcatSYCL(c->eigen_sycl_device(), inputs_flat, &output_flat); + return; + } +#endif // TENSORFLOW_USE_SYCL ConcatCPU(c->device(), inputs_flat, &output_flat); } } @@ -207,6 +216,39 @@ REGISTER_KERNEL_BUILDER(Name("ConcatV2") #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL(type) \ + REGISTER_KERNEL_BUILDER(Name("Concat") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .HostMemory("concat_dim"), \ + ConcatOp) \ + REGISTER_KERNEL_BUILDER(Name("ConcatV2") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("axis"), \ + ConcatV2Op) + +TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL); +REGISTER_KERNEL_BUILDER(Name("Concat") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .HostMemory("concat_dim") + .HostMemory("values") + .HostMemory("output"), + ConcatOp); +REGISTER_KERNEL_BUILDER(Name("ConcatV2") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Tidx") + .HostMemory("values") + .HostMemory("axis") + .HostMemory("output"), + ConcatV2Op); +#undef REGISTER_SYCL +#endif // TENSORFLOW_USE_SYCL + class ConcatOffsetOp : public OpKernel { public: explicit ConcatOffsetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -293,4 +335,12 @@ REGISTER_KERNEL_BUILDER(Name("ConcatOffset") .HostMemory("offset"), ConcatOffsetOp); +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("ConcatOffset") + .Device(DEVICE_SYCL) + .HostMemory("concat_dim") + .HostMemory("shape") + .HostMemory("offset"), + ConcatOffsetOp); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 306736fe544..a0f89f2abd6 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -57,7 +57,10 @@ REGISTER_KERNEL_BUILDER(Name("Const").Device(DEVICE_CPU), ConstantOp); REGISTER_KERNEL_BUILDER( \ Name("Const").Device(DEVICE_SYCL).TypeConstraint("dtype"), \ ConstantOp); -TF_CALL_NUMBER_TYPES(REGISTER_SYCL_KERNEL); +REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); +REGISTER_SYCL_KERNEL(bool); +REGISTER_SYCL_KERNEL(int64); #undef REGISTER_SYCL_KERNEL #endif @@ -112,6 +115,17 @@ REGISTER_KERNEL_BUILDER(Name("Const") HostConstantOp); #endif +#ifdef TENSORFLOW_USE_SYCL +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Const") + .Device(DEVICE_SYCL) + .HostMemory("output") + .TypeConstraint("dtype"), + HostConstantOp); +#endif // TENSORFLOW_USE_SYCL + typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; #ifdef TENSORFLOW_USE_SYCL @@ -186,6 +200,7 @@ REGISTER_KERNEL(CPU, quint8); #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL(SYCL, float) +REGISTER_KERNEL(SYCL, double) REGISTER_KERNEL_BUILDER(Name("Fill") .Device(DEVICE_SYCL) .TypeConstraint("T") @@ -245,6 +260,7 @@ TF_CALL_POD_STRING_TYPES(REGISTER_CPU); #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL(float, SYCL); +REGISTER_KERNEL(bool, SYCL); REGISTER_KERNEL_BUILDER(Name("ZerosLike") .Device(DEVICE_SYCL) .TypeConstraint("T") diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index 1a73a3d0f8f..6a79be5a952 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -321,6 +321,30 @@ TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_REF_KERNEL); #undef REGISTER_SYCL_KERNEL #undef REGISTER_SYCL_REF_KERNEL +#define REGISTER_SYCL_HOST_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Enter") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + EnterOp) + +#define REGISTER_SYCL_HOST_REF_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("RefEnter") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint("T"), \ + EnterOp) + +REGISTER_SYCL_HOST_KERNEL(int32); +REGISTER_SYCL_HOST_REF_KERNEL(int32); +REGISTER_SYCL_HOST_KERNEL(string); +REGISTER_SYCL_HOST_REF_KERNEL(string); +REGISTER_SYCL_HOST_KERNEL(ResourceHandle); + +#undef REGISTER_SYCL_HOST_KERNEL +#undef REGISTER_SYCL_HOST_REF_KERNEL #endif // Special GPU kernels for int32 and string. diff --git a/tensorflow/core/kernels/cwise_op_acos.cc b/tensorflow/core/kernels/cwise_op_acos.cc index 1d2d815027f..65801da3c7c 100644 --- a/tensorflow/core/kernels/cwise_op_acos.cc +++ b/tensorflow/core/kernels/cwise_op_acos.cc @@ -26,6 +26,7 @@ REGISTER2(UnaryOp, CPU, "Acos", functor::acos, float, double); .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_add_1.cc b/tensorflow/core/kernels/cwise_op_add_1.cc index a6bff78694a..f6e9b59cf8d 100644 --- a/tensorflow/core/kernels/cwise_op_add_1.cc +++ b/tensorflow/core/kernels/cwise_op_add_1.cc @@ -18,7 +18,7 @@ limitations under the License. namespace tensorflow { REGISTER5(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32, int64); - + #if TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(TYPE) \ REGISTER_KERNEL_BUILDER( \ @@ -26,10 +26,19 @@ REGISTER5(BinaryOp, CPU, "Add", functor::add, float, Eigen::half, double, int32, .Device(DEVICE_SYCL) \ .TypeConstraint("T"), \ BinaryOp>); - REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL + +REGISTER_KERNEL_BUILDER(Name("Add") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); #endif // TENSORFLOW_USE_SYCL - + #if GOOGLE_CUDA REGISTER3(BinaryOp, GPU, "Add", functor::add, float, Eigen::half, double); diff --git a/tensorflow/core/kernels/cwise_op_asin.cc b/tensorflow/core/kernels/cwise_op_asin.cc index 92a22e90c4a..c9ebfe759b1 100644 --- a/tensorflow/core/kernels/cwise_op_asin.cc +++ b/tensorflow/core/kernels/cwise_op_asin.cc @@ -26,6 +26,7 @@ REGISTER2(UnaryOp, CPU, "Asin", functor::asin, float, double); .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_atan.cc b/tensorflow/core/kernels/cwise_op_atan.cc index 825e85283f4..72645b303fc 100644 --- a/tensorflow/core/kernels/cwise_op_atan.cc +++ b/tensorflow/core/kernels/cwise_op_atan.cc @@ -26,6 +26,7 @@ REGISTER2(UnaryOp, CPU, "Atan", functor::atan, float, double); .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_ceil.cc b/tensorflow/core/kernels/cwise_op_ceil.cc index c5a4aaf831f..c74e10576d5 100644 --- a/tensorflow/core/kernels/cwise_op_ceil.cc +++ b/tensorflow/core/kernels/cwise_op_ceil.cc @@ -26,6 +26,7 @@ REGISTER3(UnaryOp, CPU, "Ceil", functor::ceil, float, Eigen::half, double); .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_cos.cc b/tensorflow/core/kernels/cwise_op_cos.cc index a758da58421..634c90adc63 100644 --- a/tensorflow/core/kernels/cwise_op_cos.cc +++ b/tensorflow/core/kernels/cwise_op_cos.cc @@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Cos", functor::cos, float, Eigen::half, double, .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc index 74d8faedb5e..1e2300832fc 100644 --- a/tensorflow/core/kernels/cwise_op_div.cc +++ b/tensorflow/core/kernels/cwise_op_div.cc @@ -37,8 +37,18 @@ REGISTER5(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double, .TypeConstraint("T"), \ BinaryOp>); REGISTER_SYCL_KERNEL(float) -REGISTER_SYCL_KERNEL(int32) +REGISTER_SYCL_KERNEL(double) #undef REGISTER_SYCL_KERNEL +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Div") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); #endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8, diff --git a/tensorflow/core/kernels/cwise_op_equal_to_1.cc b/tensorflow/core/kernels/cwise_op_equal_to_1.cc index 7bd44abd393..93ea768836f 100644 --- a/tensorflow/core/kernels/cwise_op_equal_to_1.cc +++ b/tensorflow/core/kernels/cwise_op_equal_to_1.cc @@ -34,4 +34,16 @@ REGISTER_KERNEL_BUILDER(Name("Equal") BinaryOp>); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER2(BinaryOp, SYCL, "Equal", functor::equal_to, float, double); + +REGISTER_KERNEL_BUILDER(Name("Equal") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_expm1.cc b/tensorflow/core/kernels/cwise_op_expm1.cc index f1c53ca272c..5573c2bcc2f 100644 --- a/tensorflow/core/kernels/cwise_op_expm1.cc +++ b/tensorflow/core/kernels/cwise_op_expm1.cc @@ -21,4 +21,7 @@ REGISTER5(UnaryOp, CPU, "Expm1", functor::expm1, float, Eigen::half, double, #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Expm1", functor::expm1, float, Eigen::half, double); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER(UnaryOp, SYCL, "Expm1", functor::expm1, float); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_floor.cc b/tensorflow/core/kernels/cwise_op_floor.cc index 129d754b826..59e32d7f6f4 100644 --- a/tensorflow/core/kernels/cwise_op_floor.cc +++ b/tensorflow/core/kernels/cwise_op_floor.cc @@ -26,6 +26,7 @@ REGISTER3(UnaryOp, CPU, "Floor", functor::floor, float, Eigen::half, double); .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_floor_div.cc b/tensorflow/core/kernels/cwise_op_floor_div.cc index 8a600f8f95e..fa81ef0872d 100644 --- a/tensorflow/core/kernels/cwise_op_floor_div.cc +++ b/tensorflow/core/kernels/cwise_op_floor_div.cc @@ -21,17 +21,6 @@ REGISTER5(BinaryOp, CPU, "FloorDiv", functor::safe_floor_div, uint8, uint16, REGISTER3(BinaryOp, CPU, "FloorDiv", functor::floor_div_real, float, Eigen::half, double); -#if TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("FloorDiv") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint("T"), \ - BinaryOp>); -REGISTER_SYCL_KERNEL(float) -#undef REGISTER_SYCL_KERNEL -#endif // TENSORFLOW_USE_SYCL - #if GOOGLE_CUDA REGISTER4(BinaryOp, GPU, "FloorDiv", functor::floor_div, uint8, uint16, int16, int64); @@ -51,4 +40,14 @@ REGISTER_KERNEL_BUILDER(Name("FloorDiv") .TypeConstraint("T"), BinaryOp>); #endif + +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("FloorDiv") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_floor_mod.cc b/tensorflow/core/kernels/cwise_op_floor_mod.cc index 4e641a8bb33..55f8a30461f 100644 --- a/tensorflow/core/kernels/cwise_op_floor_mod.cc +++ b/tensorflow/core/kernels/cwise_op_floor_mod.cc @@ -31,4 +31,14 @@ REGISTER_KERNEL_BUILDER(Name("FloorMod") .TypeConstraint("T"), BinaryOp>); #endif + +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("FloorMod") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_greater.cc b/tensorflow/core/kernels/cwise_op_greater.cc index 8c9691d1ea2..6b5a806aa21 100644 --- a/tensorflow/core/kernels/cwise_op_greater.cc +++ b/tensorflow/core/kernels/cwise_op_greater.cc @@ -33,5 +33,19 @@ REGISTER_KERNEL_BUILDER(Name("Greater") .TypeConstraint("T"), BinaryOp>); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER(BinaryOp, SYCL, "Greater", functor::greater, float); + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Greater") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_greater_equal.cc b/tensorflow/core/kernels/cwise_op_greater_equal.cc index a6083cb9cd5..ac215282561 100644 --- a/tensorflow/core/kernels/cwise_op_greater_equal.cc +++ b/tensorflow/core/kernels/cwise_op_greater_equal.cc @@ -34,4 +34,15 @@ REGISTER_KERNEL_BUILDER(Name("GreaterEqual") BinaryOp>); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER(BinaryOp, SYCL, "GreaterEqual", functor::greater_equal, float); + +REGISTER_KERNEL_BUILDER(Name("GreaterEqual") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_isfinite.cc b/tensorflow/core/kernels/cwise_op_isfinite.cc index 59976141c78..0faeffa95ca 100644 --- a/tensorflow/core/kernels/cwise_op_isfinite.cc +++ b/tensorflow/core/kernels/cwise_op_isfinite.cc @@ -27,6 +27,7 @@ REGISTER3(UnaryOp, CPU, "IsFinite", functor::isfinite, float, Eigen::half, .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_isinf.cc b/tensorflow/core/kernels/cwise_op_isinf.cc index 675cb95b955..df63006b3fd 100644 --- a/tensorflow/core/kernels/cwise_op_isinf.cc +++ b/tensorflow/core/kernels/cwise_op_isinf.cc @@ -26,6 +26,7 @@ REGISTER3(UnaryOp, CPU, "IsInf", functor::isinf, float, Eigen::half, double); .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_isnan.cc b/tensorflow/core/kernels/cwise_op_isnan.cc index c394087ed80..e1cf7a86375 100644 --- a/tensorflow/core/kernels/cwise_op_isnan.cc +++ b/tensorflow/core/kernels/cwise_op_isnan.cc @@ -26,6 +26,7 @@ REGISTER3(UnaryOp, CPU, "IsNan", functor::isnan, float, Eigen::half, double); .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_less.cc b/tensorflow/core/kernels/cwise_op_less.cc index 701007d6376..a38f1024a9a 100644 --- a/tensorflow/core/kernels/cwise_op_less.cc +++ b/tensorflow/core/kernels/cwise_op_less.cc @@ -33,5 +33,15 @@ REGISTER_KERNEL_BUILDER(Name("Less") .TypeConstraint("T"), BinaryOp>); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER3(BinaryOp, SYCL, "Less", functor::less, float, double, int64); +REGISTER_KERNEL_BUILDER(Name("Less") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_less_equal.cc b/tensorflow/core/kernels/cwise_op_less_equal.cc index 97fd1ae9192..3a2cc2ae0e8 100644 --- a/tensorflow/core/kernels/cwise_op_less_equal.cc +++ b/tensorflow/core/kernels/cwise_op_less_equal.cc @@ -34,4 +34,16 @@ REGISTER_KERNEL_BUILDER(Name("LessEqual") BinaryOp>); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER(BinaryOp, SYCL, "LessEqual", functor::less_equal, float); + +REGISTER_KERNEL_BUILDER(Name("LessEqual") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_log.cc b/tensorflow/core/kernels/cwise_op_log.cc index 71c4588b3de..5e74e778c76 100644 --- a/tensorflow/core/kernels/cwise_op_log.cc +++ b/tensorflow/core/kernels/cwise_op_log.cc @@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Log", functor::log, float, Eigen::half, double, .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_log1p.cc b/tensorflow/core/kernels/cwise_op_log1p.cc index 03ea3a0a894..edb821318e8 100644 --- a/tensorflow/core/kernels/cwise_op_log1p.cc +++ b/tensorflow/core/kernels/cwise_op_log1p.cc @@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Log1p", functor::log1p, float, Eigen::half, double, .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_maximum.cc b/tensorflow/core/kernels/cwise_op_maximum.cc index f93b5a83031..7311f25ec0c 100644 --- a/tensorflow/core/kernels/cwise_op_maximum.cc +++ b/tensorflow/core/kernels/cwise_op_maximum.cc @@ -34,4 +34,19 @@ REGISTER_KERNEL_BUILDER(Name("Maximum") BinaryOp>); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER(BinaryOp, SYCL, "Maximum", functor::maximum, float); + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Maximum") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_minimum.cc b/tensorflow/core/kernels/cwise_op_minimum.cc index 36800975a80..99e5a766203 100644 --- a/tensorflow/core/kernels/cwise_op_minimum.cc +++ b/tensorflow/core/kernels/cwise_op_minimum.cc @@ -34,4 +34,16 @@ REGISTER_KERNEL_BUILDER(Name("Minimum") BinaryOp>); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER(BinaryOp, SYCL, "Minimum", functor::minimum, float); + +REGISTER_KERNEL_BUILDER(Name("Minimum") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_mul_1.cc b/tensorflow/core/kernels/cwise_op_mul_1.cc index e23fe6761d7..5273522626b 100644 --- a/tensorflow/core/kernels/cwise_op_mul_1.cc +++ b/tensorflow/core/kernels/cwise_op_mul_1.cc @@ -28,7 +28,15 @@ REGISTER5(BinaryOp, CPU, "Mul", functor::mul, float, Eigen::half, double, .TypeConstraint("T"), \ BinaryOp>); REGISTER_SYCL_KERNEL(float) +REGISTER_SYCL_KERNEL(double) #undef REGISTER_SYCL_KERNEL +REGISTER_KERNEL_BUILDER(Name("Mul") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .HostMemory("z") + .TypeConstraint("T"), + BinaryOp>); #endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA REGISTER4(BinaryOp, GPU, "Mul", functor::mul, float, Eigen::half, double, diff --git a/tensorflow/core/kernels/cwise_op_pow.cc b/tensorflow/core/kernels/cwise_op_pow.cc index 8eeba6ab14f..f1780168e45 100644 --- a/tensorflow/core/kernels/cwise_op_pow.cc +++ b/tensorflow/core/kernels/cwise_op_pow.cc @@ -27,6 +27,7 @@ REGISTER7(BinaryOp, CPU, "Pow", functor::pow, float, Eigen::half, double, int32, .TypeConstraint("T"), \ BinaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_reciprocal.cc b/tensorflow/core/kernels/cwise_op_reciprocal.cc index d858a077f5c..8c0e21f9cf3 100644 --- a/tensorflow/core/kernels/cwise_op_reciprocal.cc +++ b/tensorflow/core/kernels/cwise_op_reciprocal.cc @@ -36,6 +36,9 @@ REGISTER5(UnaryOp, CPU, "Reciprocal", functor::inverse, float, Eigen::half, REGISTER4(UnaryOp, GPU, "Reciprocal", functor::inverse, float, Eigen::half, double, int64); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER(UnaryOp, SYCL, "Reciprocal", functor::inverse, float); +#endif // TENSORFLOW_USE_SYCL REGISTER5(SimpleBinaryOp, CPU, "ReciprocalGrad", functor::inverse_grad, float, Eigen::half, double, complex64, complex128); @@ -43,4 +46,7 @@ REGISTER5(SimpleBinaryOp, CPU, "ReciprocalGrad", functor::inverse_grad, float, REGISTER3(SimpleBinaryOp, GPU, "ReciprocalGrad", functor::inverse_grad, float, Eigen::half, double); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER(SimpleBinaryOp, SYCL, "ReciprocalGrad", functor::inverse_grad, float); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_round.cc b/tensorflow/core/kernels/cwise_op_round.cc index 7a4482dbb2b..e192f89782d 100644 --- a/tensorflow/core/kernels/cwise_op_round.cc +++ b/tensorflow/core/kernels/cwise_op_round.cc @@ -20,9 +20,9 @@ REGISTER5(UnaryOp, CPU, "Round", functor::round, Eigen::half, float, double, int32, int64); #ifdef TENSORFLOW_USE_SYCL -REGISTER(UnaryOp, SYCL, "Round", functor::round, float); +REGISTER2(UnaryOp, SYCL, "Round", functor::round, float, double); namespace functor { -DEFINE_UNARY1(round, float); +DEFINE_UNARY2(round, float, double); } // namespace functor #endif diff --git a/tensorflow/core/kernels/cwise_op_rsqrt.cc b/tensorflow/core/kernels/cwise_op_rsqrt.cc index 7dc96d47a60..f23725f48e3 100644 --- a/tensorflow/core/kernels/cwise_op_rsqrt.cc +++ b/tensorflow/core/kernels/cwise_op_rsqrt.cc @@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Rsqrt", functor::rsqrt, float, Eigen::half, double, .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc index 8160fb74c2a..b5deffdb855 100644 --- a/tensorflow/core/kernels/cwise_op_select.cc +++ b/tensorflow/core/kernels/cwise_op_select.cc @@ -28,6 +28,10 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL + template class SelectOp : public OpKernel { public: @@ -163,12 +167,24 @@ REGISTER_SELECT_GPU(complex128); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +// Registration of the SYCL implementations. +#define REGISTER_SELECT_SYCL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Select").Device(DEVICE_SYCL).TypeConstraint("T"), \ + SelectOp); + +REGISTER_SELECT_SYCL(float); +REGISTER_SELECT_SYCL(int32); +#undef REGISTER_SELECT_SYCL +#endif // TENSORFLOW_USE_SYCL + namespace functor { // CPU Specializations of Select functors. -template -struct SelectFunctor { - void operator()(const CPUDevice& d, typename TTypes::Flat out, +template +struct SelectFunctorBase { + void operator()(const Device& d, typename TTypes::Flat out, typename TTypes::ConstFlat cond_flat, typename TTypes::ConstFlat then_flat, typename TTypes::ConstFlat else_flat) { @@ -176,10 +192,18 @@ struct SelectFunctor { } }; -// CPU Specializations of Select functors with scalar template -struct SelectScalarFunctor { - void operator()(const CPUDevice& d, typename TTypes::Flat out, +struct SelectFunctor + : SelectFunctorBase {}; +#ifdef TENSORFLOW_USE_SYCL +template +struct SelectFunctor + : SelectFunctorBase {}; +#endif // TENSORFLOW_USE_SYCL + +template +struct SelectScalarFunctorBase { + void operator()(const Device& d, typename TTypes::Flat out, TTypes::ConstScalar cond, typename TTypes::ConstFlat then_flat, typename TTypes::ConstFlat else_flat) { @@ -187,9 +211,19 @@ struct SelectScalarFunctor { } }; +// CPU Specializations of Select functors with scalar template -struct BatchSelectFunctor { - void operator()(const CPUDevice& d, +struct SelectScalarFunctor + : SelectScalarFunctorBase {}; +#ifdef TENSORFLOW_USE_SYCL +template +struct SelectScalarFunctor + : SelectScalarFunctorBase {}; +#endif // TENSORFLOW_USE_SYCL + +template +struct BatchSelectFunctorBase { + void operator()(const Device& d, typename TTypes::Matrix output_flat_outer_dims, TTypes::ConstVec cond_vec, typename TTypes::ConstMatrix then_flat_outer_dims, @@ -214,6 +248,15 @@ struct BatchSelectFunctor { } }; +template +struct BatchSelectFunctor + : BatchSelectFunctorBase {}; +#ifdef TENSORFLOW_USE_SYCL +template +struct BatchSelectFunctor + : BatchSelectFunctorBase {}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sigmoid.cc b/tensorflow/core/kernels/cwise_op_sigmoid.cc index cc1f9b8f03e..a76a088ac8f 100644 --- a/tensorflow/core/kernels/cwise_op_sigmoid.cc +++ b/tensorflow/core/kernels/cwise_op_sigmoid.cc @@ -23,6 +23,9 @@ REGISTER5(UnaryOp, CPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double, REGISTER3(UnaryOp, GPU, "Sigmoid", functor::sigmoid, float, Eigen::half, double); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER(UnaryOp, SYCL, "Sigmoid", functor::sigmoid, float); +#endif // TENSORFLOW_USE_SYCL REGISTER5(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, float, Eigen::half, double, complex64, complex128); @@ -30,5 +33,8 @@ REGISTER5(SimpleBinaryOp, CPU, "SigmoidGrad", functor::sigmoid_grad, float, REGISTER3(SimpleBinaryOp, GPU, "SigmoidGrad", functor::sigmoid_grad, float, Eigen::half, double); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER(SimpleBinaryOp, SYCL, "SigmoidGrad", functor::sigmoid_grad, float); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sign.cc b/tensorflow/core/kernels/cwise_op_sign.cc index 568906612a6..dedd414db55 100644 --- a/tensorflow/core/kernels/cwise_op_sign.cc +++ b/tensorflow/core/kernels/cwise_op_sign.cc @@ -33,4 +33,17 @@ REGISTER_KERNEL_BUILDER(Name("Sign") UnaryOp>); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER(UnaryOp, SYCL, "Sign", functor::sign, float); +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Sign") + .Device(DEVICE_SYCL) + .HostMemory("x") + .HostMemory("y") + .TypeConstraint("T"), + UnaryOp>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_sin.cc b/tensorflow/core/kernels/cwise_op_sin.cc index 8d0c0959f74..ab54c61b56d 100644 --- a/tensorflow/core/kernels/cwise_op_sin.cc +++ b/tensorflow/core/kernels/cwise_op_sin.cc @@ -27,6 +27,7 @@ REGISTER5(UnaryOp, CPU, "Sin", functor::sin, float, Eigen::half, double, .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYC diff --git a/tensorflow/core/kernels/cwise_op_sqrt.cc b/tensorflow/core/kernels/cwise_op_sqrt.cc index 710001517b5..55acf648db0 100644 --- a/tensorflow/core/kernels/cwise_op_sqrt.cc +++ b/tensorflow/core/kernels/cwise_op_sqrt.cc @@ -27,8 +27,9 @@ REGISTER5(UnaryOp, CPU, "Sqrt", functor::sqrt, float, Eigen::half, double, .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL -#endif // TENSORFLOW_USE_SYC +#endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA REGISTER3(UnaryOp, GPU, "Sqrt", functor::sqrt, float, Eigen::half, double); diff --git a/tensorflow/core/kernels/cwise_op_square.cc b/tensorflow/core/kernels/cwise_op_square.cc index f867f127a72..afcacfec1c7 100644 --- a/tensorflow/core/kernels/cwise_op_square.cc +++ b/tensorflow/core/kernels/cwise_op_square.cc @@ -27,6 +27,7 @@ REGISTER7(UnaryOp, CPU, "Square", functor::square, float, Eigen::half, double, .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYC diff --git a/tensorflow/core/kernels/cwise_op_tan.cc b/tensorflow/core/kernels/cwise_op_tan.cc index ac49cad88fd..9c850c94207 100644 --- a/tensorflow/core/kernels/cwise_op_tan.cc +++ b/tensorflow/core/kernels/cwise_op_tan.cc @@ -26,6 +26,7 @@ REGISTER2(UnaryOp, CPU, "Tan", functor::tan, float, double); .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYC diff --git a/tensorflow/core/kernels/cwise_op_tanh.cc b/tensorflow/core/kernels/cwise_op_tanh.cc index ae2c473e20b..1dbc13061ba 100644 --- a/tensorflow/core/kernels/cwise_op_tanh.cc +++ b/tensorflow/core/kernels/cwise_op_tanh.cc @@ -28,6 +28,7 @@ REGISTER5(UnaryOp, CPU, "Tanh", functor::tanh, float, Eigen::half, double, .TypeConstraint("T"), \ UnaryOp>); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif // TENSORFLOW_USE_SYC diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h index 671de380d37..77b330f5899 100644 --- a/tensorflow/core/kernels/cwise_ops_gradients.h +++ b/tensorflow/core/kernels/cwise_ops_gradients.h @@ -171,6 +171,21 @@ struct SimpleBinaryFunctor { } }; + +#ifdef TENSORFLOW_USE_SYCL +// Partial specialization of BinaryFunctor for SYCL devices +typedef Eigen::SyclDevice SYCLDevice; +template +struct SimpleBinaryFunctor { + void operator()(const SYCLDevice& d, typename Functor::tout_type out, + typename Functor::tin_type in0, + typename Functor::tin_type in1) { + out.device(d) = in0.binaryExpr(in1, typename Functor::func()); + } +}; + +#endif // TENSORFLOW_USE_SYCL + template struct tanh_grad : base> {}; diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc index 6250928aca1..92018ec8718 100644 --- a/tensorflow/core/kernels/cwise_ops_test.cc +++ b/tensorflow/core/kernels/cwise_ops_test.cc @@ -51,18 +51,38 @@ static int ColsFromArg(int arg) { return (arg % kRows); } BENCHMARK(BM_##DEVICE##_##FUNC##_##TYPE)->Range(4 << 10, 1 << 20); BM_UNARY(cpu, Floor, float, DT_FLOAT); +#if GOOGLE_CUDA BM_UNARY(gpu, Floor, float, DT_FLOAT); +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +BM_UNARY(sycl, Floor, float, DT_FLOAT); +#endif // TENSORFLOW_USE_SYCL + BM_UNARY(cpu, Floor, double, DT_DOUBLE); +#if GOOGLE_CUDA BM_UNARY(gpu, Floor, double, DT_DOUBLE); +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +BM_UNARY(sycl, Floor, double, DT_DOUBLE); +#endif // TENSORFLOW_USE_SYCL + BM_UNARY(cpu, Conj, std::complex, DT_COMPLEX64); +#if GOOGLE_CUDA BM_UNARY(gpu, Conj, std::complex, DT_COMPLEX64); +#endif // GOOGLE_CUDA BM_UNARY(cpu, Conj, std::complex, DT_COMPLEX128); +#if GOOGLE_CUDA BM_UNARY(gpu, Conj, std::complex, DT_COMPLEX128); +#endif // GOOGLE_CUDA BM_UNARY(cpu, Rint, double, DT_DOUBLE); +#if GOOGLE_CUDA BM_UNARY(gpu, Rint, double, DT_DOUBLE); +#endif // GOOGLE_CUDA BM_UNARY(cpu, Rint, float, DT_FLOAT); +#if GOOGLE_CUDA BM_UNARY(gpu, Rint, float, DT_FLOAT); +#endif // GOOGLE_CUDA // data func scalar. static Graph* BinaryScalar(int num, const string& func) { @@ -90,9 +110,20 @@ static Graph* BinaryScalar(int num, const string& func) { ->Arg(1048576); BM_BINARY_SCALAR(cpu, Less); +#if GOOGLE_CUDA BM_BINARY_SCALAR(gpu, Less); +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +BM_BINARY_SCALAR(sycl, Less); +#endif // TENSORFLOW_USE_SYCL + BM_BINARY_SCALAR(cpu, Add); +#if GOOGLE_CUDA BM_BINARY_SCALAR(gpu, Add); +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +BM_BINARY_SCALAR(sycl, Add); +#endif // TENSORFLOW_USE_SYCL #undef BM_BINARY_SCALAR template @@ -130,9 +161,13 @@ static Graph* BiasAdd(int rows, int cols, DataType type) { using Eigen::half; BM_BIAS_ADD_ALL(cpu, float, DT_FLOAT); +#if GOOGLE_CUDA BM_BIAS_ADD_ALL(gpu, float, DT_FLOAT); +#endif // GOOGLE_CUDA BM_BIAS_ADD_ALL(cpu, half, DT_HALF); +#if GOOGLE_CUDA BM_BIAS_ADD_ALL(gpu, half, DT_HALF); +#endif // GOOGLE_CUDA #undef BM_BIAS_ADD_ALL #undef BM_BIAS_ADD @@ -180,12 +215,18 @@ static Graph* BiasAddGrad(int rows, int cols, int channels, DataType type, BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 4096, 4096, 1); using Eigen::half; +#if GOOGLE_CUDA BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, float, DT_FLOAT); BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, half, DT_HALF); +#endif // GOOGLE_CUDA BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, float, DT_FLOAT); +#if GOOGLE_CUDA BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, float, DT_FLOAT); +#endif // GOOGLE_CUDA BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, half, DT_HALF); +#if GOOGLE_CUDA BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, half, DT_HALF); +#endif // GOOGLE_CUDA #undef BM_BIAS_ADD_GRAD_ALL #undef BM_BIAS_ADD_GRAD @@ -223,7 +264,12 @@ static Graph* BcastAdd(int rows, int cols, int dim) { BM_BCAST_ADD_ROW(DEVICE, 2048, 512); \ BM_BCAST_ADD_ROW(DEVICE, 4096, 512); BM_BCAST_ADD_ROW_ALL(cpu); +#if GOOGLE_CUDA BM_BCAST_ADD_ROW_ALL(gpu); +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +BM_BCAST_ADD_ROW_ALL(sycl); +#endif // TENSORFLOW_USE_SYCL #undef BM_BCAST_ADD_ROW_ALL #undef BM_BCAST_ADD_ROW @@ -244,7 +290,12 @@ BM_BCAST_ADD_ROW_ALL(gpu); BM_BCAST_ADD_COL(DEVICE, 2048, 512); \ BM_BCAST_ADD_COL(DEVICE, 4096, 512); BM_BCAST_ADD_COL_ALL(cpu); +#if GOOGLE_CUDA BM_BCAST_ADD_COL_ALL(gpu); +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +BM_BCAST_ADD_COL_ALL(sycl); +#endif // TENSORFLOW_USE_SYCL #undef BM_BCAST_ADD_COL_ALL #undef BM_BCAST_ADD_COL diff --git a/tensorflow/core/kernels/debug_ops.cc b/tensorflow/core/kernels/debug_ops.cc index 0706b72a895..d0f5db3bf2c 100644 --- a/tensorflow/core/kernels/debug_ops.cc +++ b/tensorflow/core/kernels/debug_ops.cc @@ -97,6 +97,7 @@ REGISTER_GPU_DEBUG_NAN_COUNT(double); .TypeConstraint("T"), \ DebugNanCountOp); REGISTER_GPU_DEBUG_NAN_COUNT(float); +REGISTER_GPU_DEBUG_NAN_COUNT(double); #endif // TENSORFLOW_USE_SYCL // Register debug numeric summary ops. @@ -129,6 +130,7 @@ REGISTER_GPU_DEBUG_NUMERIC_SUMMARY_COUNT(double); .TypeConstraint("T"), \ DebugNumericSummaryOp); REGISTER_GPU_DEBUG_NUMERIC_SUMMARY_COUNT(float); +REGISTER_GPU_DEBUG_NUMERIC_SUMMARY_COUNT(double); #endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc index 42fe6e88c91..767f143727c 100644 --- a/tensorflow/core/kernels/dense_update_ops.cc +++ b/tensorflow/core/kernels/dense_update_ops.cc @@ -152,6 +152,7 @@ typedef Eigen::SyclDevice SYCLDevice; DenseUpdateOp); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL #endif diff --git a/tensorflow/core/kernels/fill_functor.cc b/tensorflow/core/kernels/fill_functor.cc index 08ec4baff3e..0df8f9d3edf 100644 --- a/tensorflow/core/kernels/fill_functor.cc +++ b/tensorflow/core/kernels/fill_functor.cc @@ -62,6 +62,8 @@ void SetZeroFunctor::operator()( #define DEFINE_SETZERO_SYCL(T) \ template struct SetZeroFunctor; DEFINE_SETZERO_SYCL(float); +DEFINE_SETZERO_SYCL(bool); +DEFINE_SETZERO_SYCL(double); #undef DEFINE_SETZERO_SYCL #endif // TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 9aa289c3c95..d08dec46d19 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -185,6 +185,34 @@ REGISTER_KERNEL_BUILDER(Name("_ArrayToList") .TypeConstraint("T"), PassOn); +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("_ListToArray").Device(DEVICE_SYCL).TypeConstraint("T"),\ + PassOn); \ + REGISTER_KERNEL_BUILDER( \ + Name("_ArrayToList").Device(DEVICE_SYCL).TypeConstraint("T"),\ + PassOn); + +REGISTER_SYCL_KERNELS(float); +REGISTER_SYCL_KERNELS(double); + +#undef REGISTER_SYCL_KERNELS + +REGISTER_KERNEL_BUILDER(Name("_ListToArray") + .Device(DEVICE_SYCL) + .HostMemory("input") + .HostMemory("output") + .TypeConstraint("T"), + PassOn); +REGISTER_KERNEL_BUILDER(Name("_ArrayToList") + .Device(DEVICE_SYCL) + .HostMemory("input") + .HostMemory("output") + .TypeConstraint("T"), + PassOn); +#endif // TENSORFLOW_USE_SYCL + class SymbolicGradientOp : public AsyncOpKernel { public: SymbolicGradientOp(OpKernelConstruction* ctx) diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc index a2b0127fac1..57c055885c8 100644 --- a/tensorflow/core/kernels/matmul_op.cc +++ b/tensorflow/core/kernels/matmul_op.cc @@ -46,6 +46,9 @@ perftools::gputools::DeviceMemory AsDeviceMemory(const T* cuda_memory) { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL template struct LaunchMatMul; @@ -118,27 +121,42 @@ bool ExplicitVectorMatrixOptimization( return false; } -// On CPUs, we ignore USE_CUBLAS -template -struct LaunchMatMulCPU { +template +struct LaunchMatMulBase { static void launch( OpKernelContext* ctx, OpKernel* kernel, const Tensor& a, const Tensor& b, const Eigen::array, 1>& dim_pair, Tensor* out) { +#ifndef TENSORFLOW_USE_SYCL // An explicit vector-matrix multiply is much better optimized than an // implicit one and this is a bottleneck during non-batched inference. bool was_vector = ExplicitVectorMatrixOptimization(a, b, dim_pair, out); if (!was_vector) { - functor::MatMulFunctor()(ctx->eigen_device(), +#endif // TENSORFLOW_USE_SYCL + functor::MatMulFunctor()(ctx->eigen_device(), out->matrix(), a.matrix(), b.matrix(), dim_pair); +#ifndef TENSORFLOW_USE_SYCL } +#endif // TENSORFLOW_USE_SYCL } }; +// On CPUs, we ignore USE_CUBLAS +template +struct LaunchMatMulCPU : LaunchMatMulBase {}; + template struct LaunchMatMul : public LaunchMatMulCPU {}; +#ifdef TENSORFLOW_USE_SYCL +template +struct LaunchMatMulSYCL : LaunchMatMulBase {}; + +template +struct LaunchMatMul : public LaunchMatMulSYCL {}; +#endif // TENSORFLOW_USE_SYCL + #if GOOGLE_CUDA template @@ -256,6 +274,20 @@ struct MatMulFunctor { } }; +#ifdef TENSORFLOW_USE_SYCL +// Partial specialization MatMulFunctor. +template +struct MatMulFunctor { + void operator()( + const SYCLDevice& d, typename MatMulTypes::out_type out, + typename MatMulTypes::in_type in0, + typename MatMulTypes::in_type in1, + const Eigen::array, 1>& dim_pair) { + MatMul(d, out, in0, in1, dim_pair); + } +}; +#endif // TENSORFLOW_USE_SYCL + } // end namespace functor #define REGISTER_CPU(T) \ @@ -294,4 +326,17 @@ TF_CALL_half(REGISTER_GPU); #endif #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("MatMul").Device(DEVICE_SYCL).TypeConstraint("T"), \ + MatMulOp); \ + REGISTER_KERNEL_BUILDER(Name("MatMul") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .Label("eigen"), \ + MatMulOp) +TF_CALL_float(REGISTER_SYCL); + +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc index 4977ad1d7cb..a6650f369ba 100644 --- a/tensorflow/core/kernels/pack_op.cc +++ b/tensorflow/core/kernels/pack_op.cc @@ -167,6 +167,7 @@ REGISTER_KERNEL_BUILDER(Name("Pack") PackOp) REGISTER_SYCL(float); +REGISTER_SYCL(double); #undef REGISTER_SYCL // A special GPU kernel for int32. diff --git a/tensorflow/core/kernels/pad_op.cc b/tensorflow/core/kernels/pad_op.cc index bec2d02cb5a..91984319c60 100644 --- a/tensorflow/core/kernels/pad_op.cc +++ b/tensorflow/core/kernels/pad_op.cc @@ -38,6 +38,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL template class PadOp : public OpKernel { @@ -199,4 +202,30 @@ REGISTER_KERNEL_BUILDER(Name("Pad") PadOp); #endif +#ifdef TENSORFLOW_USE_SYCL +// Registration of the GPU implementations. +#define REGISTER_SYCL_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("Pad") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tpaddings") \ + .HostMemory("paddings"), \ + PadOp) + +REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Pad") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Tpaddings") + .HostMemory("input") + .HostMemory("paddings") + .HostMemory("output"), + PadOp); +#endif // TENSORFLOW_USE_SYCL + } // end namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h index 625cea42282..19071b47f14 100644 --- a/tensorflow/core/kernels/reduction_ops_common.h +++ b/tensorflow/core/kernels/reduction_ops_common.h @@ -268,6 +268,31 @@ struct ReduceFunctor template struct ReduceFunctor : ReduceFunctorBase{}; + +template +struct ReduceFunctor > { + template + static void Reduce(const SYCLDevice& d, OUT_T out, IN_T in, + const ReductionAxes& reduction_axes, + const Eigen::internal::MeanReducer& reducer) { + typedef typename IN_T::Index Index; + // Eigen sum reductions are much faster on GPU than mean reductions: + // Simply trigger them by computing the sum of the weighted inputs. + Index num_coeffs_to_reduce = 1; + for (int i = 0; i < Eigen::internal::array_size::value; + ++i) { + num_coeffs_to_reduce *= in.dimension(reduction_axes[i]); + } + T scale = T(1.0) / num_coeffs_to_reduce; + out.device(d) = (in * scale).sum(reduction_axes); + } + + template + static void FillIdentity(const SYCLDevice& d, OUT_T out, + const Eigen::internal::MeanReducer& reducer) { + FillIdentityEigenImpl(d, out, reducer); + } +}; #endif // TENSORFLOW_USE_SYCL } // namespace functor diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc index db86157c8ee..5ab97d1eeec 100644 --- a/tensorflow/core/kernels/reduction_ops_max.cc +++ b/tensorflow/core/kernels/reduction_ops_max.cc @@ -57,4 +57,27 @@ REGISTER_KERNEL_BUILDER( #endif +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Max") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); +REGISTER_SYCL_KERNELS(float); +#undef REGISTER_SYCL_KERNELS + +REGISTER_KERNEL_BUILDER( + Name("Max") + .Device(DEVICE_SYCL) + .HostMemory("reduction_indices") + .HostMemory("input") + .HostMemory("output") + .TypeConstraint("T") + .TypeConstraint("Tidx"), + ReductionOp>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_mean.cc b/tensorflow/core/kernels/reduction_ops_mean.cc index fef3cd06991..e018cb55dd1 100644 --- a/tensorflow/core/kernels/reduction_ops_mean.cc +++ b/tensorflow/core/kernels/reduction_ops_mean.cc @@ -44,4 +44,17 @@ REGISTER_GPU_KERNELS(double); #endif +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Mean") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); +REGISTER_SYCL_KERNELS(float); +#undef REGISTER_SYCL_KERNELS +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc index c362bc88674..ec240421b9a 100644 --- a/tensorflow/core/kernels/reduction_ops_min.cc +++ b/tensorflow/core/kernels/reduction_ops_min.cc @@ -57,4 +57,27 @@ REGISTER_KERNEL_BUILDER( #endif +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Min") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); +REGISTER_SYCL_KERNELS(float); +#undef REGISTER_SYCL_KERNELS + +REGISTER_KERNEL_BUILDER( + Name("Min") + .Device(DEVICE_SYCL) + .HostMemory("reduction_indices") + .HostMemory("input") + .HostMemory("output") + .TypeConstraint("T") + .TypeConstraint("Tidx"), + ReductionOp>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_prod.cc b/tensorflow/core/kernels/reduction_ops_prod.cc index c6aff8c2ed4..e04c655dabb 100644 --- a/tensorflow/core/kernels/reduction_ops_prod.cc +++ b/tensorflow/core/kernels/reduction_ops_prod.cc @@ -45,4 +45,28 @@ REGISTER_GPU_KERNELS(double); #endif +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Prod") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("reduction_indices"), \ + ReductionOp>); +REGISTER_SYCL_KERNELS(float); +REGISTER_SYCL_KERNELS(double); +#undef REGISTER_SYCL_KERNELS + +REGISTER_KERNEL_BUILDER( + Name("Prod") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Tidx") + .HostMemory("input") + .HostMemory("output") + .HostMemory("reduction_indices"), + ReductionOp>); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc index 3aa38f418ee..938ca66a0cb 100644 --- a/tensorflow/core/kernels/reduction_ops_sum.cc +++ b/tensorflow/core/kernels/reduction_ops_sum.cc @@ -74,7 +74,6 @@ REGISTER_KERNEL_BUILDER( .HostMemory("reduction_indices"), \ ReductionOp>); REGISTER_SYCL_KERNELS(float); -REGISTER_SYCL_KERNELS(double); #undef REGISTER_SYCL_KERNELS // A special GPU kernel for int32. diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc index f24a71ec8ca..d70398bea57 100644 --- a/tensorflow/core/kernels/relu_op.cc +++ b/tensorflow/core/kernels/relu_op.cc @@ -29,6 +29,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL #define REGISTER_RELU_KERNELS(type) \ REGISTER_KERNEL_BUILDER( \ @@ -131,4 +134,30 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +// Registration of the GPU implementations. +#define REGISTER_SYCL_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu").Device(DEVICE_SYCL).TypeConstraint("T"), \ + ReluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("ReluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + ReluGradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6").Device(DEVICE_SYCL).TypeConstraint("T"), \ + Relu6Op); \ + REGISTER_KERNEL_BUILDER( \ + Name("Relu6Grad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + Relu6GradOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("Elu").Device(DEVICE_SYCL).TypeConstraint("T"), \ + EluOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("EluGrad").Device(DEVICE_SYCL).TypeConstraint("T"), \ + EluGradOp) + +REGISTER_SYCL_KERNELS(float); +#undef REGISTER_SYCL_KERNELS +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h index 365c6201a54..e2e0bd48dd1 100644 --- a/tensorflow/core/kernels/relu_op.h +++ b/tensorflow/core/kernels/relu_op.h @@ -175,6 +175,10 @@ void EluGradOp::OperateNoTemplate(OpKernelContext* context, } // namespace tensorflow +#ifdef TENSORFLOW_USE_SYCL +#undef EIGEN_USE_SYCL +#endif // TENSORFLOW_USE_SYCL + #undef EIGEN_USE_THREADS #endif // TENSORFLOW_KERNELS_RELU_OP_H_ diff --git a/tensorflow/core/kernels/reverse_op.cc b/tensorflow/core/kernels/reverse_op.cc index 7852499965c..596dac9087a 100644 --- a/tensorflow/core/kernels/reverse_op.cc +++ b/tensorflow/core/kernels/reverse_op.cc @@ -33,6 +33,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL namespace { @@ -351,4 +354,36 @@ REGISTER_KERNEL_BUILDER(Name("ReverseV2") ReverseV2Op); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNELS(T) \ + REGISTER_KERNEL_BUILDER(Name("Reverse") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .HostMemory("dims"), \ + ReverseOp) \ + REGISTER_KERNEL_BUILDER(Name("ReverseV2") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tidx") \ + .HostMemory("axis"), \ + ReverseV2Op) +TF_CALL_float(REGISTER_SYCL_KERNELS); + +REGISTER_KERNEL_BUILDER(Name("Reverse") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .HostMemory("tensor") + .HostMemory("dims") + .HostMemory("output"), + ReverseOp); +REGISTER_KERNEL_BUILDER(Name("ReverseV2") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Tidx") + .HostMemory("tensor") + .HostMemory("axis") + .HostMemory("output"), + ReverseV2Op); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 827eb7dbca7..51dad49cfec 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -180,8 +180,8 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU); #define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL); -REGISTER_SCATTER_ARITHEMTIC_SYCL(float); -REGISTER_SCATTER_UPDATE_SYCL(float); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHEMTIC_SYCL); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_SYCL); #undef REGISTER_SCATTER_ARITHEMTIC_SYCL #undef REGISTER_SCATTER_UPDATE_SYCL diff --git a/tensorflow/core/kernels/sequence_ops.cc b/tensorflow/core/kernels/sequence_ops.cc index c24ecdf8b97..c8ea9230201 100644 --- a/tensorflow/core/kernels/sequence_ops.cc +++ b/tensorflow/core/kernels/sequence_ops.cc @@ -92,9 +92,11 @@ class RangeOp : public OpKernel { #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL(DEVICE_SYCL, T) TF_CALL_float(REGISTER_SYCL_KERNEL); +TF_CALL_double(REGISTER_SYCL_KERNEL); TF_CALL_int32(REGISTER_SYCL_KERNEL); TF_CALL_int64(REGISTER_SYCL_KERNEL); -#endif // TENSORFLOW_USE_SYCL +#undef REGISTER_SYCL_KERNEL +#endif // TENSORFLOW_USE_SYCL TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); @@ -170,4 +172,9 @@ TF_CALL_double(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_GPU_KERNEL); TF_CALL_double(REGISTER_GPU_KERNEL); +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL(DEVICE_SYCL, T) +TF_CALL_float(REGISTER_SYCL_KERNEL); +TF_CALL_double(REGISTER_SYCL_KERNEL); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc index 6bc0b4560b7..177a32464ba 100644 --- a/tensorflow/core/kernels/shape_ops.cc +++ b/tensorflow/core/kernels/shape_ops.cc @@ -201,6 +201,7 @@ REGISTER_KERNEL_BUILDER(Name("Rank").Device(DEVICE_CPU).HostMemory("output"), .HostMemory("output"), \ RankOp); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL // A special GPU kernel for int32 and bool. @@ -297,6 +298,43 @@ REGISTER_KERNEL_BUILDER(Name("Size") SizeOp); #endif +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("out_type") \ + .HostMemory("output"), \ + SizeOp); \ + REGISTER_KERNEL_BUILDER(Name("Size") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("out_type") \ + .HostMemory("output"), \ + SizeOp); +REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); +#undef REGISTER_SYCL_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Size") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("out_type") + .HostMemory("input") + .HostMemory("output"), + SizeOp); +REGISTER_KERNEL_BUILDER(Name("Size") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("out_type") + .HostMemory("input") + .HostMemory("output"), + SizeOp); +#endif // TENSORFLOW_USE_SYCL + // ExpandDims ------------------------------------ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .Device(DEVICE_CPU) @@ -323,7 +361,30 @@ REGISTER_KERNEL_BUILDER(Name("ExpandDims") .HostMemory("dim") .HostMemory("output"), ExpandDimsOp); -#endif +#endif // GOOGLE_CUDA + +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("ExpandDims") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tdim") \ + .HostMemory("dim"), \ + ExpandDimsOp); +REGISTER_SYCL_KERNEL(float) +REGISTER_SYCL_KERNEL(double) + +#undef REGISTER_SYCL_KERNEL + +REGISTER_KERNEL_BUILDER(Name("ExpandDims") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Tdim") + .HostMemory("input") + .HostMemory("dim") + .HostMemory("output"), + ExpandDimsOp); +#endif // TENSORFLOW_USE_SYCL // Squeeze --------------------------------------- REGISTER_KERNEL_BUILDER(Name("Squeeze").Device(DEVICE_CPU), SqueezeOp); @@ -347,4 +408,24 @@ REGISTER_KERNEL_BUILDER(Name("Squeeze") SqueezeOp); #endif +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("Squeeze").Device(DEVICE_SYCL).TypeConstraint("T"),\ + SqueezeOp); +REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); +#undef REGISTER_SYCL_KERNEL + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("Squeeze") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .HostMemory("input") + .HostMemory("output"), + SqueezeOp); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/softmax_op.cc b/tensorflow/core/kernels/softmax_op.cc index c7ae93852f8..de11de32f12 100644 --- a/tensorflow/core/kernels/softmax_op.cc +++ b/tensorflow/core/kernels/softmax_op.cc @@ -28,17 +28,27 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL // Partial specialization for a CPUDevice, that uses the Eigen implementation // from SoftmaxEigenImpl. namespace functor { -template -struct SoftmaxFunctor { - void operator()(const CPUDevice& d, typename TTypes::ConstMatrix logits, +template +struct SoftmaxFunctorBase { + void operator()(const Device& d, typename TTypes::ConstMatrix logits, typename TTypes::Matrix softmax, const bool log) { - SoftmaxEigenImpl::Compute(d, logits, softmax, log); + SoftmaxEigenImpl::Compute(d, logits, softmax, log); } }; +template +struct SoftmaxFunctor : SoftmaxFunctorBase {}; + +#ifdef TENSORFLOW_USE_SYCL +template +struct SoftmaxFunctor : SoftmaxFunctorBase {}; +#endif // TENSORFLOW_USE_SYCL } // namespace functor #define REGISTER_CPU(T) \ @@ -76,4 +86,10 @@ REGISTER_KERNEL_BUILDER( SoftmaxOp); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("Softmax").Device(DEVICE_SYCL).TypeConstraint("T"), + SoftmaxOp); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/stage_op.cc b/tensorflow/core/kernels/stage_op.cc index c18b992ea15..161ba892127 100644 --- a/tensorflow/core/kernels/stage_op.cc +++ b/tensorflow/core/kernels/stage_op.cc @@ -99,6 +99,9 @@ REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_CPU), StageOp); #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_GPU), StageOp); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_SYCL), StageOp); +#endif // TENSORFLOW_USE_SYCL class UnstageOp : public OpKernel { public: @@ -126,5 +129,8 @@ REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_CPU), UnstageOp); #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_GPU), UnstageOp); #endif +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_SYCL), UnstageOp); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index a3738655098..8a3d09f1c19 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -450,4 +450,71 @@ REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") #undef REGISTER_GPU #endif // GOOGLE_CUDA + +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_SYCL(type) \ + REGISTER_KERNEL_BUILDER(Name("StridedSlice") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides") \ + .TypeConstraint("Index"), \ + StridedSliceOp) \ + REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .HostMemory("shape") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides") \ + .TypeConstraint("Index"), \ + StridedSliceGradOp)\ + REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .HostMemory("begin") \ + .HostMemory("end") \ + .HostMemory("strides") \ + .TypeConstraint("Index"), \ + StridedSliceAssignOp) + +REGISTER_SYCL(float); +REGISTER_SYCL(double); + +// A special GPU kernel for int32. +// TODO(b/25387198): Also enable int32 in device memory. This kernel +// registration requires all int32 inputs and outputs to be in host memory. +REGISTER_KERNEL_BUILDER(Name("StridedSlice") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Index") + .HostMemory("input") + .HostMemory("begin") + .HostMemory("end") + .HostMemory("strides") + .HostMemory("output"), + StridedSliceOp); +REGISTER_KERNEL_BUILDER(Name("StridedSliceGrad") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Index") + .HostMemory("shape") + .HostMemory("begin") + .HostMemory("end") + .HostMemory("strides") + .HostMemory("dy") + .HostMemory("output"), + StridedSliceGradOp); +REGISTER_KERNEL_BUILDER(Name("StridedSliceAssign") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Index") + .HostMemory("ref") + .HostMemory("begin") + .HostMemory("end") + .HostMemory("strides"), + StridedSliceAssignOp) +#undef REGISTER_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h index e89d1920b9c..93cede398a6 100644 --- a/tensorflow/core/kernels/strided_slice_op_impl.h +++ b/tensorflow/core/kernels/strided_slice_op_impl.h @@ -285,6 +285,20 @@ DECLARE_FOR_N_GPU(int32); TF_CALL_ALL_TYPES(DECLARE_FOR_N_CPU); DECLARE_FOR_N_CPU(bfloat16); +#ifdef TENSORFLOW_USE_SYCL +#define PREVENT_FOR_N_SYCL(T) \ + PREVENT_INSTANTIATE(T, STRIDED_SLICE_INSTANTIATE_DIM) + +#define DECLARE_FOR_N_SYCL(T) \ + INSTANTIATE(SYCLDevice, T, STRIDED_SLICE_INSTANTIATE_DIM) + +TF_CALL_SYCL_PROXY_TYPES(PREVENT_FOR_N_SYCL); +TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N_SYCL); +DECLARE_FOR_N_SYCL(int32); + +#undef DECLARE_FOR_N_SYCL +#endif // TENSORFLOW_USE_SYCL + #undef INSTANTIATE #undef DECLARE_FOR_N_CPU #undef DECLARE_FOR_N_GPU diff --git a/tensorflow/core/kernels/tile_ops.cc b/tensorflow/core/kernels/tile_ops.cc index e55c8679e92..9822b021ebc 100644 --- a/tensorflow/core/kernels/tile_ops.cc +++ b/tensorflow/core/kernels/tile_ops.cc @@ -260,6 +260,8 @@ TF_CALL_complex128(HANDLE_TYPE_NAME_GPU); #ifdef TENSORFLOW_USE_SYCL TF_CALL_float(HANDLE_TYPE_NAME_SYCL); +TF_CALL_double(HANDLE_TYPE_NAME_SYCL); +TF_CALL_int32(HANDLE_TYPE_NAME_SYCL); #endif // TENSORFLOW_USE_SYCL #undef HANDLE_TYPE_NAME_CPU @@ -506,6 +508,16 @@ TF_CALL_complex64(HANDLE_TYPE_NAME_GPU); TF_CALL_complex128(HANDLE_TYPE_NAME_GPU); #endif // GOOGLE_CUDA +#if TENSORFLOW_USE_SYCL +#define HANDLE_TYPE_NAME_SYCL(T) \ + HANDLE_CASE_DIM(SYCLDevice, T, DataTypeToEnum::value); + +TF_CALL_float(HANDLE_TYPE_NAME_SYCL); +TF_CALL_double(HANDLE_TYPE_NAME_SYCL); +TF_CALL_int32(HANDLE_TYPE_NAME_SYCL); +#undef HANDLE_TYPE_NAME_SYCL +#endif // TENSORFLOW_USE_SYCL + #undef HANDLE_TYPE_NAME_CPU #undef HANDLE_TYPE_NAME_GPU #undef HANDLE_CASE_DIM @@ -605,6 +617,25 @@ REGISTER_KERNEL_BUILDER(Name("Tile") .TypeConstraint("Tmultiples") .HostMemory("multiples"), TileOp); +REGISTER_KERNEL_BUILDER(Name("Tile") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Tmultiples") + .HostMemory("multiples"), + TileOp); + +REGISTER_KERNEL_BUILDER(Name("TileGrad") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Tmultiples") + .HostMemory("multiples"), + TileGradientOp); +REGISTER_KERNEL_BUILDER(Name("TileGrad") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .TypeConstraint("Tmultiples") + .HostMemory("multiples"), + TileGradientOp); #endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/tile_ops_cpu_impl.h b/tensorflow/core/kernels/tile_ops_cpu_impl.h index 650c739ed59..f06cc5514c8 100644 --- a/tensorflow/core/kernels/tile_ops_cpu_impl.h +++ b/tensorflow/core/kernels/tile_ops_cpu_impl.h @@ -70,6 +70,8 @@ typedef Eigen::SyclDevice SYCLDevice; #define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM) TF_CALL_float(DEFINE_TYPE); +TF_CALL_double(DEFINE_TYPE); +TF_CALL_int32(DEFINE_TYPE); #undef DEFINE_DIM #undef DEFINE_TYPE @@ -81,6 +83,8 @@ TF_CALL_float(DEFINE_TYPE); #define DEFINE_TYPE(T) DEFINE_DIM(T, CPU_PROVIDED_IXDIM) TF_CALL_float(DEFINE_TYPE); +TF_CALL_double(DEFINE_TYPE); +TF_CALL_int32(DEFINE_TYPE); #undef DEFINE_DIM #undef DEFINE_TYPE diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 336c6b0ccc9..5c2d371430f 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -423,6 +423,7 @@ TF_CALL_double(REGISTER_CPU_KERNELS); #ifdef TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T); TF_CALL_float(REGISTER_SYCL_KERNELS); +TF_CALL_double(REGISTER_SYCL_KERNELS); #undef REGISTER_SYCL_KERNELS #endif @@ -2355,6 +2356,7 @@ TF_CALL_double(REGISTER_CPU_KERNELS); #define REGISTER_SYCL_KERNELS(T) REGISTER_KERNELS(SYCL, T); TF_CALL_float(REGISTER_SYCL_KERNELS); +TF_CALL_double(REGISTER_SYCL_KERNELS); #endif #if GOOGLE_CUDA diff --git a/tensorflow/core/kernels/transpose_functor_cpu.cc b/tensorflow/core/kernels/transpose_functor_cpu.cc index 30b82f18431..3681b9a1291 100644 --- a/tensorflow/core/kernels/transpose_functor_cpu.cc +++ b/tensorflow/core/kernels/transpose_functor_cpu.cc @@ -127,6 +127,7 @@ Status DoTranspose(const SYCLDevice& d, const Tensor& in, switch (in.dtype()) { case DT_FLOAT: + case DT_DOUBLE: case DT_INT32: internal::Transpose(d, in, perm, out); break; diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index 67300c1e961..4d303f01732 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -82,6 +82,15 @@ REGISTER_KERNEL_BUILDER(Name("InvertPermutation") .HostMemory("y"), InvertPermutationOp); +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("InvertPermutation") + .Device(DEVICE_SYCL) + .TypeConstraint("T") + .HostMemory("x") + .HostMemory("y"), + InvertPermutationOp); +#endif // TENSORFLOW_USE_SYCL + // output = TransposeOp(T input, T perm) takes a tensor // of type T and rank N, and a permutation of 0, 1, ..., N-1. It // shuffles the dimensions of the input tensor according to permutation. @@ -201,4 +210,24 @@ TF_CALL_POD_TYPES(REGISTER); #undef REGISTER #endif +#ifdef TENSORFLOW_USE_SYCL +Status TransposeSyclOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, + gtl::ArraySlice perm, Tensor* out) { + typedef Eigen::SyclDevice SYCLDevice; + return ::tensorflow::DoTranspose(ctx->eigen_device(), in, perm, + out); +} +#define REGISTER(T) \ + REGISTER_KERNEL_BUILDER(Name("Transpose") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint("T") \ + .TypeConstraint("Tperm") \ + .HostMemory("perm"), \ + TransposeSyclOp); +REGISTER(float); +REGISTER(bool); +REGISTER(int32); +#undef REGISTER +#endif + } // namespace tensorflow diff --git a/tensorflow/core/kernels/transpose_op.h b/tensorflow/core/kernels/transpose_op.h index 3b209c0ccc8..5f40bcecc18 100644 --- a/tensorflow/core/kernels/transpose_op.h +++ b/tensorflow/core/kernels/transpose_op.h @@ -50,6 +50,17 @@ class TransposeGpuOp : public TransposeOp { gtl::ArraySlice perm, Tensor* out) override; }; +#ifdef TENSORFLOW_USE_SYCL +class TransposeSyclOp : public TransposeOp { + public: + explicit TransposeSyclOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {} + + protected: + Status DoTranspose(OpKernelContext* ctx, const Tensor& in, + gtl::ArraySlice perm, Tensor* out) override; +}; +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow #endif // TENSORFLOW_KERNELS_TRANSPOSE_OP_H_ diff --git a/tensorflow/core/kernels/unpack_op.cc b/tensorflow/core/kernels/unpack_op.cc index 2a14fa32651..e4c79ae17bb 100644 --- a/tensorflow/core/kernels/unpack_op.cc +++ b/tensorflow/core/kernels/unpack_op.cc @@ -160,6 +160,7 @@ REGISTER_KERNEL_BUILDER(Name("Unpack") UnpackOp) REGISTER_SYCL(float); +REGISTER_SYCL(double); #undef REGISTER_SYCL // A special SYCL kernel for int32. diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc index 34e227156d8..7a4d9dc6503 100644 --- a/tensorflow/core/kernels/variable_ops.cc +++ b/tensorflow/core/kernels/variable_ops.cc @@ -58,8 +58,9 @@ REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized").Device(DEVICE_CPU), IsVariableInitializedOp); REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); #undef REGISTER_SYCL_KERNEL -#endif +#endif // TENSORFLOW_USE_SYCL #if GOOGLE_CUDA // Only register 'Variable' on GPU for the subset of types also supported by diff --git a/tensorflow/core/kernels/xent_op.cc b/tensorflow/core/kernels/xent_op.cc index 639bad5f04f..56cad8e9eb1 100644 --- a/tensorflow/core/kernels/xent_op.cc +++ b/tensorflow/core/kernels/xent_op.cc @@ -28,6 +28,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL template class SoftmaxXentWithLogitsOp : public OpKernel { @@ -74,17 +77,25 @@ class SoftmaxXentWithLogitsOp : public OpKernel { // Partial specialization for a CPUDevice, that uses the Eigen implementation // from XentEigenImpl. namespace functor { -template -struct XentFunctor { - void operator()(const CPUDevice& d, typename TTypes::ConstMatrix logits, +template +struct XentFunctorBase { + void operator()(const Device& d, typename TTypes::ConstMatrix logits, typename TTypes::ConstMatrix labels, typename TTypes::Matrix scratch, typename TTypes::Vec loss, typename TTypes::Matrix backprop) { - XentEigenImpl::Compute(d, logits, labels, scratch, loss, + XentEigenImpl::Compute(d, logits, labels, scratch, loss, backprop); } }; + +template +struct XentFunctor : XentFunctorBase {}; + +#ifdef TENSORFLOW_USE_SYCL +template +struct XentFunctor : XentFunctorBase {}; +#endif // TENSORFLOW_USE_SYCL } // namespace functor #define REGISTER_CPU(T) \ @@ -111,4 +122,11 @@ REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") SoftmaxXentWithLogitsOp); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits") + .Device(DEVICE_SYCL) + .TypeConstraint("T"), + SoftmaxXentWithLogitsOp); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc index 2def59ff04a..8670ca307c1 100644 --- a/tensorflow/core/ops/math_grad_test.cc +++ b/tensorflow/core/ops/math_grad_test.cc @@ -390,7 +390,7 @@ class TestOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("TestOpWithNoGrad").Device(DEVICE_CPU), TestOp); #ifdef TENSORFLOW_USE_SYCL REGISTER_KERNEL_BUILDER(Name("TestOpWithNoGrad").Device(DEVICE_SYCL), TestOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL TEST_F(MathGradTest, Error_Reporting) { auto x = test::AsTensor({-3.f}); @@ -707,6 +707,8 @@ TEST_F(MathGradTest, Pow) { } } +//TODO{lukeiwanski}: Implement Complex Pow for SYCL +#ifndef TENSORFLOW_USE_SYCL TEST_F(MathGradTest, ComplexPow) { auto x = test::AsTensor({0.f, 2.f, -2.f}, TensorShape({3})); auto y = test::AsTensor({2.f, 2.f, 2.f}, TensorShape({3})); @@ -725,6 +727,7 @@ TEST_F(MathGradTest, ComplexPow) { dy, test::AsTensor({h(0.f, 2.f), h(2.f, 2.f), h(-2.f, 2.f)}, TensorShape({3}))); } +#endif // TENSORFLOW_USE_SYCL TEST_F(MathGradTest, Maximum) { auto x = test::AsTensor({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f}, @@ -886,6 +889,8 @@ TEST_F(MathGradTest, MatMul_11) { test::ExpectClose(dy, MatMul(dz, true, x, true)); } +//TODO{lukeiwanski}: Implement BatchMatMul for SYCL +#ifndef TENSORFLOW_USE_SYCL TEST_F(MathGradTest, BatchMatMul_00) { auto x = test::AsTensor({1.f, 2.f, 3.f, 4.f, 5.f, 6.f}, TensorShape({1, 2, 3})); @@ -933,6 +938,7 @@ TEST_F(MathGradTest, BatchMatMul_11) { test::ExpectClose(dx, BatchMatMul(y, true, dz, true)); test::ExpectClose(dy, BatchMatMul(dz, true, x, true)); } +#endif // TENSORFLOW_USE_SYCL TEST_F(MathGradTest, Sum_dim0) { auto x = test::AsTensor({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f}, diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md index 709dfd53cfb..e9315c0750d 100644 --- a/tensorflow/g3doc/get_started/os_setup.md +++ b/tensorflow/g3doc/get_started/os_setup.md @@ -706,7 +706,7 @@ for production (though it will mature over time). ##### Download and install OpenCL drivers The exact steps required for a functional OpenCL installation will depend on -your environment. For Unbuntu 14.04, the following steps are known to work: +your environment. For Ubuntu 14.04, the following steps are known to work: ```bash sudo apt-get install ocl-icd-opencl-dev opencl-headers @@ -727,12 +727,17 @@ and copy the files into e.g. `/usr/local/computecpp`: ```bash tar -xvzf ComputeCpp-CE-0.1.1-Ubuntu.14.04-64bit.tar.gz -sudo mkdir /usr/local/computecpp sudo cp -R ComputeCpp-CE-0.1.1-Linux /usr/local/computecpp sudo chmod -R a+r /usr/local/computecpp/ sudo chmod -R a+x /usr/local/computecpp/bin ``` +Add the lib folder to your `LD_LIBRARY_PATH` to make Python find `libComputeCpp.so` by adding the following line to your `~/.bash_profile`: + +```bash +export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/computecpp/lib" +``` + ### Prepare environment for Mac OS X We recommend using [homebrew](http://brew.sh) to install the bazel dependency, @@ -1305,6 +1310,12 @@ This is typically because you have the Xcode build tools installed, but you stil When importing tensorflow, you may see an "ImportError" raised. Below are some possible examples and solutions: +``` +ImportError: cannot import name 'pywrap_tensorflow' +``` + +This can occur if you try to import tensorflow while your current working directory is in the same directory as TensorFlow is located. If that is the case, change the working directory (i.e. `cd` in bash or `os.chdir` in python) to some folder outside of the TensorFlow directory and try importing tensorflow again. + ``` ImportError: /lib64/libc.so.6: version `GLIBC_2.16' not found (required by ..._pywrap_tensorflow.so) ``` @@ -1323,3 +1334,8 @@ directory, and so tries to import directly from the source code instead of your installed tensorflow package. Solution: don't import tensorflow from the tensorflow source code root directory, if you are. +``` +ImportError: libComputeCpp.so: cannot open shared object file: No such file or directory +``` + +Make sure you have added the path to ComputeCpp's `lib` folder to your `LD_LIBRARY_PATH` (as mentioned above). diff --git a/tensorflow/python/client/device_lib_test.py b/tensorflow/python/client/device_lib_test.py index 561ce09099b..7bba10efacf 100644 --- a/tensorflow/python/client/device_lib_test.py +++ b/tensorflow/python/client/device_lib_test.py @@ -34,7 +34,7 @@ class DeviceLibTest(test_util.TensorFlowTestCase): # GPU test if test.is_gpu_available(): self.assertGreater(len(devices), 1) - self.assertTrue("GPU" in [d.device_type for d in devices]) + self.assertTrue("GPU" in [d.device_type for d in devices] or "SYCL" in [d.device_type for d in devices]) if __name__ == "__main__": diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index f2fd687adf6..3ea7e547ee1 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -44,7 +44,14 @@ from tensorflow.python.platform import googletest from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import compat from tensorflow.python.util.protobuf import compare +from tensorflow.python.client import device_lib +def gpu_device_name(): + """Returns the name of a GPU device if available or the empty string.""" + for x in device_lib.list_local_devices(): + if x.device_type == 'GPU' or x.device_type == 'SYCL': + return x.name + return '' def assert_ops_in_graph(expected_ops, graph): """Assert all expected operations are found. @@ -301,7 +308,12 @@ class TensorFlowTestCase(googletest.TestCase): sess = self._cached_session with sess.graph.as_default(), sess.as_default(): if force_gpu: - with sess.graph.device("/gpu:0"): + # Use the name of an actual device if one is detected, or '/gpu:0' + # otherwise + gpu_name = gpu_device_name() + if len(gpu_name) == 0: + gpu_name = '/gpu:0' + with sess.graph.device(gpu_name): yield sess elif use_gpu: yield sess @@ -311,7 +323,12 @@ class TensorFlowTestCase(googletest.TestCase): else: with session.Session(graph=graph, config=prepare_config(config)) as sess: if force_gpu: - with sess.graph.device("/gpu:0"): + # Use the name of an actual device if one is detected, or '/gpu:0' + # otherwise + gpu_name = gpu_device_name() + if len(gpu_name) == 0: + gpu_name = '/gpu:0' + with sess.graph.device(gpu_name): yield sess elif use_gpu: yield sess diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 1e510b28689..04c2c0c9e73 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1342,7 +1342,7 @@ class ControlFlowTest(test.TestCase): def _testWhileGrad_ColocateGradients(self, colocate): gpu_dev_name = test.gpu_device_name() if test.is_gpu_available() else "/gpu:0" - gpu_short_name = gpu_dev_name.split('/')[-1] + gpu_short_name = gpu_dev_name.split('/')[-1].lower() with self.test_session(graph=ops.Graph()) as sess: v = constant_op.constant(2.0, name="v") diff --git a/tensorflow/python/kernel_tests/stage_op_test.py b/tensorflow/python/kernel_tests/stage_op_test.py index c2d760809be..e5f449fcf2c 100644 --- a/tensorflow/python/kernel_tests/stage_op_test.py +++ b/tensorflow/python/kernel_tests/stage_op_test.py @@ -31,7 +31,7 @@ class StageTest(test.TestCase): with ops.device('/cpu:0'): x = array_ops.placeholder(dtypes.float32) v = 2. * (array_ops.zeros([1024, 1024]) + x) - with ops.device('/gpu:0'): + with ops.device(test.gpu_device_name()): stager = data_flow_ops.StagingArea([dtypes.float32]) stage = stager.put([v]) y = stager.get() @@ -46,7 +46,7 @@ class StageTest(test.TestCase): with ops.device('/cpu:0'): x = array_ops.placeholder(dtypes.float32) v = 2. * (array_ops.zeros([128, 128]) + x) - with ops.device('/gpu:0'): + with ops.device(test.gpu_device_name()): stager = data_flow_ops.StagingArea([dtypes.float32, dtypes.float32]) stage = stager.put([x, v]) z, y = stager.get() @@ -62,7 +62,7 @@ class StageTest(test.TestCase): with ops.device('/cpu:0'): x = array_ops.placeholder(dtypes.float32) v = 2. * (array_ops.zeros([128, 128]) + x) - with ops.device('/gpu:0'): + with ops.device(test.gpu_device_name()): stager = data_flow_ops.StagingArea( [dtypes.float32, dtypes.float32], shapes=[[], [128, 128]], diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index 18cf8f7b99c..501f0c8b352 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -40,6 +40,7 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: disable=unused-import from tensorflow.python.framework.test_util import assert_equal_graph_def from tensorflow.python.framework.test_util import TensorFlowTestCase as TestCase +from tensorflow.python.framework.test_util import gpu_device_name from tensorflow.python.ops.gradient_checker import compute_gradient_error from tensorflow.python.ops.gradient_checker import compute_gradient @@ -105,15 +106,6 @@ def is_gpu_available(cuda_only=False): return any((x.device_type == 'GPU' or x.device_type == 'SYCL') for x in _device_lib.list_local_devices()) - -def gpu_device_name(): - """Returns the name of a GPU device if available or the empty string.""" - for x in _device_lib.list_local_devices(): - if x.device_type == 'GPU' or x.device_type == 'SYCL': - return x.name - return '' - - _allowed_symbols = [ # We piggy-back googletest documentation. 'Benchmark', diff --git a/third_party/sycl/crosstool/computecpp.tpl b/third_party/sycl/crosstool/computecpp.tpl index a5e6b9fe938..66dd9aea7be 100755 --- a/third_party/sycl/crosstool/computecpp.tpl +++ b/third_party/sycl/crosstool/computecpp.tpl @@ -26,9 +26,7 @@ def main(): if(output_file_index == 1): # we are linking - return subprocess.call([CPU_CXX_COMPILER] + compiler_flags) - - compiler_flags = compiler_flags + ['-D_GLIBCXX_USE_CXX11_ABI=0', '-DEIGEN_USE_SYCL=1'] + return subprocess.call([CPU_CXX_COMPILER] + compiler_flags + ['-Wl,--no-undefined']) # find what we compile compiling_cpp = 0 @@ -38,6 +36,28 @@ def main(): if(compited_file_name.endswith(('.cc', '.c++', '.cpp', '.CPP', '.C', '.cxx'))): compiling_cpp = 1; + compiler_flags = compiler_flags + ['-D_GLIBCXX_USE_CXX11_ABI=0', '-DEIGEN_USE_SYCL=1', '-DTENSORFLOW_USE_SYCL', '-DEIGEN_HAS_C99_MATH'] + + if(compiling_cpp == 1): + # create a blacklist of folders that will be skipped when compiling with ComputeCpp + _skip = ["external", "llvm", ".cu.cc"] + # if compiling external project skip computecpp + if any(_folder in _skip for _folder in output_file_name): + return subprocess.call([CPU_CXX_COMPILER] + compiler_flags) + + if(compiling_cpp == 1): + # this is an optimisation that will check if compiled file has to be compiled with ComputeCpp + + _tmp_flags = [flag for flag in compiler_flags if not flag.startswith(('-o', output_file_name))] + # create preprocessed of the file + _cmd = " ".join([CPU_CXX_COMPILER] + _tmp_flags + ["-E"]) + # check if it has parallel_for< in it + _cmd += " | grep \".parallel_for\" > /dev/null" + ps = subprocess.call(_cmd, shell=True) + # if not call CXX compiler + if(ps != 0): + return subprocess.call([CPU_CXX_COMPILER] + compiler_flags) + if(compiling_cpp == 1): filename, file_extension = os.path.splitext(output_file_name) bc_out = filename + '.sycl' @@ -52,9 +72,12 @@ def main(): # dont want that in case of compiling with computecpp first host_compiler_flags = [flag for flag in compiler_flags if not flag.startswith(('-MF', '-MD',)) - if not '.d' in flag] + if not '.d' in flag + ] - host_compiler_flags = ['-D_GLIBCXX_USE_CXX11_ABI=0', '-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable', '-I', COMPUTECPP_INCLUDE, '--include', bc_out] + host_compiler_flags + host_compiler_flags[host_compiler_flags.index('-c')] = "--include" + + host_compiler_flags = ['-xc++', '-D_GLIBCXX_USE_CXX11_ABI=0', '-DTENSORFLOW_USE_SYCL', '-Wno-unused-variable', '-I', COMPUTECPP_INCLUDE, '-c', bc_out] + host_compiler_flags x = subprocess.call([CPU_CXX_COMPILER] + host_compiler_flags) return x else: diff --git a/tools/bazel.rc.template b/tools/bazel.rc.template index 48c9f5aa3f9..3622b9423c2 100644 --- a/tools/bazel.rc.template +++ b/tools/bazel.rc.template @@ -7,7 +7,7 @@ build:sycl --crosstool_top=@local_config_sycl//crosstool:toolchain build:sycl --define=using_sycl=true build:sycl_asan --crosstool_top=@local_config_sycl//crosstool:toolchain -build:sycl_asan --define=using_sycl=true --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -fsanitize=address --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -lasan +build:sycl_asan --define=using_sycl=true --copt -fno-omit-frame-pointer --copt -fsanitize-coverage=3 --copt -DGPR_NO_DIRECT_SYSCALLS --linkopt -fPIC --linkopt -fsanitize=address build --force_python=py$PYTHON_MAJOR_VERSION build --host_force_python=py$PYTHON_MAJOR_VERSION