[OpenCL] Registers AdjustContrastv2 (#10949)
* [OpenCL] Registers AdjustContrastv2 (#93) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL (#96) * [OpenCL] Extended adjust_contrast_op_benchmark_test for OpenCL * simplified to #ifndef * Changed to "#if GOOGLE_CUDA" * Update adjust_contrast_op_benchmark_test.cc * Added comments
This commit is contained in:
parent
5248a48f00
commit
832894ef89
@ -31,6 +31,9 @@ namespace tensorflow {
|
|||||||
|
|
||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
|
typedef Eigen::SyclDevice SYCLDevice;
|
||||||
|
#endif
|
||||||
|
|
||||||
// AdjustContrastOp is deprecated as of GraphDef version >= 2
|
// AdjustContrastOp is deprecated as of GraphDef version >= 2
|
||||||
|
|
||||||
@ -410,4 +413,25 @@ REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_GPU),
|
|||||||
AdjustContrastOpv2<GPUDevice>);
|
AdjustContrastOpv2<GPUDevice>);
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
|
template <>
|
||||||
|
class AdjustContrastOpv2<SYCLDevice> : public AdjustContrastOpV2Base {
|
||||||
|
public:
|
||||||
|
explicit AdjustContrastOpv2(OpKernelConstruction* context)
|
||||||
|
: AdjustContrastOpV2Base(context) {}
|
||||||
|
|
||||||
|
void DoCompute(OpKernelContext* context,
|
||||||
|
const ComputeOptions& options) override {
|
||||||
|
const int64 shape[4] = {options.batch, options.height, options.width,
|
||||||
|
options.channels};
|
||||||
|
functor::AdjustContrastv2<SYCLDevice>()(
|
||||||
|
context->eigen_device<SYCLDevice>(),
|
||||||
|
options.input->shaped<float, 4>(shape), options.factor->scalar<float>(),
|
||||||
|
options.output->shaped<float, 4>(shape));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("AdjustContrastv2").Device(DEVICE_SYCL),
|
||||||
|
AdjustContrastOpv2<SYCLDevice>);
|
||||||
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -56,6 +56,11 @@ static Graph* BM_AdjustContrast(int batches, int width, int height) {
|
|||||||
// BM_AdjustContrast_cpu_1_299_299 179084 340186 2181 751.9M items/s
|
// BM_AdjustContrast_cpu_1_299_299 179084 340186 2181 751.9M items/s
|
||||||
// BM_AdjustContrast_gpu_32_299_299 85276 123665 4189 2.9G items/s
|
// BM_AdjustContrast_gpu_32_299_299 85276 123665 4189 2.9G items/s
|
||||||
BM_AdjustContrastDev(cpu, 1, 299, 299);
|
BM_AdjustContrastDev(cpu, 1, 299, 299);
|
||||||
|
#if GOOGLE_CUDA
|
||||||
BM_AdjustContrastDev(gpu, 32, 299, 299);
|
BM_AdjustContrastDev(gpu, 32, 299, 299);
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
|
BM_AdjustContrastDev(sycl, 32, 299, 299);
|
||||||
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user