diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 4c3cea4b6ff..12581d0bfa5 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -114,6 +114,21 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { void Execute(const T* src_data, const T* diff_filter_data, const T* diff_bias_data, const T* diff_dst_data, std::shared_ptr bwd_filter_stream) { + // TODO: Create a common function and avoid the duplicate code +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data)), *bwd_filter_stream); + context_.diff_filter_mem->set_data_handle( + static_cast(const_cast(diff_filter_data)), + *bwd_filter_stream); + if (diff_bias_data != nullptr) { + context_.diff_bias_mem->set_data_handle( + static_cast(const_cast(diff_bias_data)), + *bwd_filter_stream); + } + context_.diff_dst_mem->set_data_handle( + static_cast(const_cast(diff_dst_data)), *bwd_filter_stream); +#else context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.diff_filter_mem->set_data_handle( @@ -124,7 +139,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { } context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); - +#endif // ENABLE_MKLDNN_THREADPOOL #ifdef ENABLE_MKLDNN_V1 execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream, context_.bwd_filter_primitives_args); diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index f9c8d11c67c..7177431029a 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -116,13 +116,22 @@ class MklConvBwdInputPrimitive : public MklPrimitive { void Execute(const T* diff_src_data, const T* filter_data, const T* diff_dst_data, std::shared_ptr bwd_input_stream) { + // TODO: Create a common function and avoid the duplicate code +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.diff_src_mem->set_data_handle( + static_cast(const_cast(diff_src_data)), *bwd_input_stream); + context_.filter_mem->set_data_handle( + static_cast(const_cast(filter_data)), *bwd_input_stream); + context_.diff_dst_mem->set_data_handle( + static_cast(const_cast(diff_dst_data)), *bwd_input_stream); +#else context_.diff_src_mem->set_data_handle( static_cast(const_cast(diff_src_data))); context_.filter_mem->set_data_handle( static_cast(const_cast(filter_data))); context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); - +#endif // ENABLE_MKLDNN_THREADPOOL #ifdef ENABLE_MKLDNN_V1 execute_primitives(context_.bwd_input_primitives, bwd_input_stream, context_.bwd_input_primitives_args); diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 7d0510d03ac..210044436aa 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -110,6 +110,19 @@ class MklConvFwdPrimitive : public MklPrimitive { void Execute(const Tinput* src_data, const Tfilter* filter_data, const Tbias* bias_data, const Toutput* dst_data, std::shared_ptr fwd_stream) { + // TODO: Create a common function and avoid the duplicate code +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data)), *fwd_stream); + context_.filter_mem->set_data_handle( + static_cast(const_cast(filter_data)), *fwd_stream); + if (bias_data != nullptr) { + context_.bias_mem->set_data_handle( + static_cast(const_cast(bias_data)), *fwd_stream); + } + context_.dst_mem->set_data_handle( + static_cast(const_cast(dst_data)), *fwd_stream); +#else context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.filter_mem->set_data_handle( @@ -120,6 +133,7 @@ class MklConvFwdPrimitive : public MklPrimitive { } context_.dst_mem->set_data_handle( static_cast(const_cast(dst_data))); +#endif // ENABLE_MKLDNN_THREADPOOL #ifdef ENABLE_MKLDNN_V1 DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_primitives_args.size()); diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index 954ae0492df..3b2c4f84039 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -94,6 +94,28 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { void Execute(const T* src_data, const U* weights_data, T* dst_data, U* mean_data, U* variance_data, std::shared_ptr fwd_stream, U* workspace_data) { + // TODO: Create a common function and avoid the duplicate code +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data)), *fwd_stream); + context_.dst_mem->set_data_handle(static_cast(dst_data), + *fwd_stream); + + if (IS_SET(use_scale_shift)) + context_.weights_mem->set_data_handle( + static_cast(const_cast(weights_data)), *fwd_stream); + + if ((context_.pkind == prop_kind::forward_training) || + (IS_SET(use_global_stats))) { + context_.mean_mem->set_data_handle(static_cast(mean_data), + *fwd_stream); + context_.variance_mem->set_data_handle(static_cast(variance_data), + *fwd_stream); + } + if (workspace_data != nullptr) { + context_.ws_mem->set_data_handle(workspace_data, *fwd_stream); + } +#else context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.dst_mem->set_data_handle(static_cast(dst_data)); @@ -110,6 +132,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { if (workspace_data != nullptr) { context_.ws_mem->set_data_handle(workspace_data); } +#endif // ENABLE_MKLDNN_THREADPOOL #ifdef ENABLE_MKLDNN_V1 // Execute batch-normalization forward primitives. execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args); @@ -503,6 +526,27 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { const T* diff_dst_data, const U* weights_data, T* diff_src_data, U* diff_weights_data, U* res_space_data, std::shared_ptr bwd_stream) { + // TODO: Create a common function and avoid the duplicate code +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data)), *bwd_stream); + context_.mean_mem->set_data_handle( + static_cast(const_cast(mean_data)), *bwd_stream); + context_.variance_mem->set_data_handle( + static_cast(const_cast(variance_data)), *bwd_stream); + context_.diff_dst_mem->set_data_handle( + static_cast(const_cast(diff_dst_data)), *bwd_stream); + + if (IS_SET(use_scale_shift)) { + context_.weights_mem->set_data_handle( + static_cast(const_cast(weights_data)), *bwd_stream); + context_.diff_weights_mem->set_data_handle( + static_cast(diff_weights_data), *bwd_stream); + } + + context_.diff_src_mem->set_data_handle(static_cast(diff_src_data), + *bwd_stream); +#else context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.mean_mem->set_data_handle( @@ -520,7 +564,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { } context_.diff_src_mem->set_data_handle(static_cast(diff_src_data)); - +#endif // ENABLE_MKLDNN_THREADPOOL #ifdef ENABLE_MKLDNN_V1 // Execute backward batch-normalization primitives. DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size()); diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc index 2dfc6db0075..5f1c9129ec3 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc @@ -127,6 +127,17 @@ template void MklPoolingFwdPrimitive::Execute(const T* src_data, T* dst_data, void* ws_data, std::shared_ptr fwd_stream) { +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data)), *fwd_stream); + context_.dst_mem->set_data_handle(static_cast(dst_data), *fwd_stream); + if (context_.alg_kind == ALGORITHM::pooling_max && + context_.prop_kind == + prop_kind::forward_training) { // Max pooling must have workspace. + DCHECK(ws_data != nullptr); + context_.ws_mem->set_data_handle(ws_data, *fwd_stream); + } +#else context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.dst_mem->set_data_handle(static_cast(dst_data)); @@ -136,7 +147,7 @@ void MklPoolingFwdPrimitive::Execute(const T* src_data, T* dst_data, DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(ws_data); } - +#endif // ENABLE_MKLDNN_THREADPOOL #ifdef ENABLE_MKLDNN_V1 execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args); #else @@ -269,6 +280,16 @@ template void MklPoolingBwdPrimitive::Execute(const T* diff_dst_data, T* diff_src_data, const void* ws_data, std::shared_ptr bwd_stream) { +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.diff_dst_mem->set_data_handle( + static_cast(const_cast(diff_dst_data)), *bwd_stream); + context_.diff_src_mem->set_data_handle(static_cast(diff_src_data), + *bwd_stream); + if (context_.alg_kind == ALGORITHM::pooling_max) { + DCHECK(ws_data != nullptr); + context_.ws_mem->set_data_handle(const_cast(ws_data), *bwd_stream); + } +#else context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); context_.diff_src_mem->set_data_handle(static_cast(diff_src_data)); @@ -276,7 +297,7 @@ void MklPoolingBwdPrimitive::Execute(const T* diff_dst_data, DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(const_cast(ws_data)); } - +#endif // ENABLE_MKLDNN_THREADPOOL #ifdef ENABLE_MKLDNN_V1 execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args); #else diff --git a/tensorflow/core/kernels/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl_quantize_op.cc index 5adb9862250..177cbb43d0b 100644 --- a/tensorflow/core/kernels/mkl_quantize_op.cc +++ b/tensorflow/core/kernels/mkl_quantize_op.cc @@ -88,8 +88,13 @@ class MklReorderWithScalePrimitive : public MklPrimitive { void Execute(void* src_data, void* dst_data, std::shared_ptr reorder_stream) { +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.src_mem->set_data_handle(src_data, *reorder_stream); + context_.dst_mem->set_data_handle(dst_data, *reorder_stream); +#else context_.src_mem->set_data_handle(src_data); context_.dst_mem->set_data_handle(dst_data); +#endif // ENABLE_MKLDNN_THREADPOOL #ifndef ENABLE_MKLDNN_V1 reorder_stream->submit(context_.net); #else diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 784bbc682dc..9af580de777 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -79,10 +79,16 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { // dst_data: output data buffer of dst void Execute(const T* src_data, T* dst_data, std::shared_ptr fwd_stream) { +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data)), *fwd_stream); + context_.dst_mem->set_data_handle(static_cast(dst_data), + *fwd_stream); +#else context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.dst_mem->set_data_handle(static_cast(dst_data)); - +#endif // ENABLE_MKLDNN_THREADPOOL #ifdef ENABLE_MKLDNN_V1 DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_primitives_args.size()); @@ -293,12 +299,20 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { // diff_src_data: output data buffer of diff_src void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data, std::shared_ptr bwd_stream) { +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data)), *bwd_stream); + context_.diff_dst_mem->set_data_handle( + static_cast(const_cast(diff_dst_data)), *bwd_stream); + context_.diff_src_mem->set_data_handle(static_cast(diff_src_data), + *bwd_stream); +#else context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); context_.diff_src_mem->set_data_handle(static_cast(diff_src_data)); - +#endif // ENABLE_MKLDNN_THREADPOOL #ifdef ENABLE_MKLDNN_V1 DCHECK_EQ(context_.bwd_primitives.size(), context_.bwd_primitives_args.size()); diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc index 4115691c79d..7e293e14d98 100644 --- a/tensorflow/core/kernels/mkl_slice_op.cc +++ b/tensorflow/core/kernels/mkl_slice_op.cc @@ -189,9 +189,15 @@ class MklSlicePrimitive : public MklPrimitive { void Execute(const MklSliceParams& sliceParams, std::shared_ptr slice_stream) { +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.src_mem->set_data_handle(sliceParams.from->get_data_handle(), + *slice_stream); + context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle(), + *slice_stream); +#else context_.src_mem->set_data_handle(sliceParams.from->get_data_handle()); context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle()); - +#endif // ENABLE_MKLDNN_THREADPOOL #ifdef ENABLE_MKLDNN_V1 execute_primitives(context_.slice_primitives, slice_stream, context_.slice_primitives_args); diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc index 4d1cf90f28d..2f51573fe13 100644 --- a/tensorflow/core/kernels/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl_softmax_op.cc @@ -59,10 +59,16 @@ class MklSoftmaxPrimitive : public MklPrimitive { // dst_data: output data buffer of dst void Execute(const T* src_data, T* dst_data, std::shared_ptr fwd_cpu_stream) { +#ifdef ENABLE_MKLDNN_THREADPOOL + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data)), *fwd_cpu_stream); + context_.dst_mem->set_data_handle(static_cast(dst_data), + *fwd_cpu_stream); +#else context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.dst_mem->set_data_handle(static_cast(dst_data)); - +#endif // ENABLE_MKLDNN_THREADPOOL #ifdef ENABLE_MKLDNN_V1 DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size()); execute_primitives(context_.fwd_primitives, fwd_cpu_stream,