diff --git a/tensorflow/core/kernels/adjust_contrast_op.cc b/tensorflow/core/kernels/adjust_contrast_op.cc index c8f12f91a6c..37976f71837 100644 --- a/tensorflow/core/kernels/adjust_contrast_op.cc +++ b/tensorflow/core/kernels/adjust_contrast_op.cc @@ -31,6 +31,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // AdjustContrastOp is deprecated as of GraphDef version >= 2 @@ -410,4 +413,25 @@ REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_GPU), AdjustContrastOpv2); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +template <> +class AdjustContrastOpv2 : public AdjustContrastOpV2Base { + public: + explicit AdjustContrastOpv2(OpKernelConstruction* context) + : AdjustContrastOpV2Base(context) {} + + void DoCompute(OpKernelContext* context, + const ComputeOptions& options) override { + const int64 shape[4] = {options.batch, options.height, options.width, + options.channels}; + functor::AdjustContrastv2()( + context->eigen_device(), + options.input->shaped(shape), options.factor->scalar(), + options.output->shaped(shape)); + } +}; +REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_SYCL), + AdjustContrastOpv2); +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow diff --git a/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc b/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc index ffd47406eb6..c485f148448 100644 --- a/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc +++ b/tensorflow/core/kernels/adjust_contrast_op_benchmark_test.cc @@ -56,6 +56,11 @@ static Graph* BM_AdjustContrast(int batches, int width, int height) { // BM_AdjustContrast_cpu_1_299_299 179084 340186 2181 751.9M items/s // BM_AdjustContrast_gpu_32_299_299 85276 123665 4189 2.9G items/s BM_AdjustContrastDev(cpu, 1, 299, 299); +#if GOOGLE_CUDA BM_AdjustContrastDev(gpu, 32, 299, 299); +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +BM_AdjustContrastDev(sycl, 32, 299, 299); +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow