diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 269513f2e7d..4c3cea4b6ff 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -97,9 +97,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { public: explicit MklConvBwdFilterPrimitive( const MklConvBwdFilterParams& convBwdFilterDims) - : cpu_engine_(ENGINE_CPU, 0) { - context_.bwd_filter_stream.reset(new CPU_STREAM(cpu_engine_)); - + : MklPrimitive(engine(ENGINE_CPU, 0)) { // Create convolution backward filter primitive. if (context_.conv_bwd_filter == nullptr) { Setup(convBwdFilterDims); @@ -114,7 +112,8 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { // diff_bias_data: output data buffer for diff_bias // diff_dst_data: input data buffer for diff_dst void Execute(const T* src_data, const T* diff_filter_data, - const T* diff_bias_data, const T* diff_dst_data) { + const T* diff_bias_data, const T* diff_dst_data, + std::shared_ptr bwd_filter_stream) { context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.diff_filter_mem->set_data_handle( @@ -127,11 +126,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { static_cast(const_cast(diff_dst_data))); #ifdef ENABLE_MKLDNN_V1 - execute_primitives(context_.bwd_filter_primitives, - context_.bwd_filter_stream, + execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream, context_.bwd_filter_primitives_args); #else - context_.bwd_filter_stream->submit(context_.bwd_filter_primitives); + bwd_filter_stream->submit(context_.bwd_filter_primitives); #endif context_.src_mem->set_data_handle(DummyData); @@ -147,8 +145,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { // diff_filter_data: output data buffer of diff_filter // diff_dst_data: input data buffer of diff_dst void Execute(const T* src_data, const T* diff_filter_data, - const T* diff_dst_data) { - Execute(src_data, diff_filter_data, nullptr, diff_dst_data); + const T* diff_dst_data, + std::shared_ptr bwd_filter_stream) { + Execute(src_data, diff_filter_data, nullptr, diff_dst_data, + bwd_filter_stream); } #ifndef ENABLE_MKLDNN_V1 @@ -223,8 +223,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { src_md(nullptr), diff_filter_md(nullptr), diff_bias_md(nullptr), - diff_dst_md(nullptr), - bwd_filter_stream(nullptr) { + diff_dst_md(nullptr) { } }; @@ -345,7 +344,6 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { } struct ConvBwdFilterContext context_; - engine cpu_engine_; }; template @@ -600,8 +598,10 @@ class MklConvCustomBackpropFilterOp auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc(); if (IS_SRC_REORDER_NEEDED(fwd_src_md, bwd_filter_pd, conv_bwd_filter)) { src.SetUsrMem(fwd_src_md, &src_tensor); - src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( - bwd_filter_pd->PRIMITIVE_DESC_SRC, cpu_engine_)); + src.CheckReorderToOpMem( + MEMORY_PD_WITHOUT_DATA(bwd_filter_pd->PRIMITIVE_DESC_SRC, + cpu_engine_), + context); src_data = static_cast(src.GetOpMem().get_data_handle()); } else { src_data = static_cast(const_cast(src_tensor.flat().data())); @@ -612,8 +612,10 @@ class MklConvCustomBackpropFilterOp if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_filter_pd, conv_bwd_filter)) { diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); - diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( - bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_)); + diff_dst.CheckReorderToOpMem( + MEMORY_PD_WITHOUT_DATA(bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, + cpu_engine_), + context); diff_dst_data = static_cast(diff_dst.GetOpMem().get_data_handle()); } else { diff_dst_data = @@ -646,18 +648,21 @@ class MklConvCustomBackpropFilterOp } // Execute convolution backward filter. + std::shared_ptr bwd_cpu_stream; + bwd_cpu_stream.reset(CreateStream(context, conv_bwd_filter->GetEngine())); if (bias_enabled) { T* diff_bias_data = static_cast(const_cast(diff_bias_tensor->flat().data())); conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data, - diff_dst_data); + diff_dst_data, bwd_cpu_stream); } else { - conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data); + conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data, + bwd_cpu_stream); } // Reorder diff_filter back to Tensorflow layout if necessary. if (diff_filter_reorder_required) { - diff_filter.InsertReorderToUserMem(); + diff_filter.InsertReorderToUserMem(context); } // Delete primitive since it is not cached. diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index bcd0446b748..f9c8d11c67c 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -99,9 +99,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { public: explicit MklConvBwdInputPrimitive( const MklConvBwdInputParams& convBwdInputDims) - : cpu_engine_(ENGINE_CPU, 0) { - context_.bwd_input_stream.reset(new CPU_STREAM(cpu_engine_)); - + : MklPrimitive(engine(ENGINE_CPU, 0)) { // Create conv bwd input primitive if (context_.conv_bwd_input == nullptr) { Setup(convBwdInputDims); @@ -116,7 +114,8 @@ class MklConvBwdInputPrimitive : public MklPrimitive { // diff_dst_data: input data buffer for dst // Bias does not matter here void Execute(const T* diff_src_data, const T* filter_data, - const T* diff_dst_data) { + const T* diff_dst_data, + std::shared_ptr bwd_input_stream) { context_.diff_src_mem->set_data_handle( static_cast(const_cast(diff_src_data))); context_.filter_mem->set_data_handle( @@ -125,10 +124,10 @@ class MklConvBwdInputPrimitive : public MklPrimitive { static_cast(const_cast(diff_dst_data))); #ifdef ENABLE_MKLDNN_V1 - execute_primitives(context_.bwd_input_primitives, context_.bwd_input_stream, + execute_primitives(context_.bwd_input_primitives, bwd_input_stream, context_.bwd_input_primitives_args); #else - context_.bwd_input_stream->submit(context_.bwd_input_primitives); + bwd_input_stream->submit(context_.bwd_input_primitives); #endif // ENABLE_MKLDNN_V1 // Set data handle back to DummyData. @@ -180,7 +179,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive { std::shared_ptr diff_dst_md; // MKL-DNN pipeline for executing primitives. - std::shared_ptr bwd_input_stream; std::vector bwd_input_primitives; #ifdef ENABLE_MKLDNN_V1 @@ -203,8 +201,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive { fwd_pd(nullptr), diff_src_md(nullptr), filter_md(nullptr), - diff_dst_md(nullptr), - bwd_input_stream(nullptr) { + diff_dst_md(nullptr) { } }; @@ -290,7 +287,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive { } struct ConvBwdInputContext context_; - engine cpu_engine_; }; template @@ -522,8 +518,10 @@ class MklConvCustomBackpropInputOp if (IS_FILTER_REORDER_NEEDED(fwd_filter_md, bwd_input_pd, conv_bwd_input)) { filter.SetUsrMem(fwd_filter_md, &filter_tensor); - filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( - bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_)); + filter.CheckReorderToOpMem( + MEMORY_PD_WITHOUT_DATA(bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, + cpu_engine_), + context); filter_data = static_cast(filter.GetOpMem().get_data_handle()); } else { filter_data = @@ -535,23 +533,29 @@ class MklConvCustomBackpropInputOp if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_input_pd, conv_bwd_input)) { diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); - diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( - bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, cpu_engine_)); + diff_dst.CheckReorderToOpMem( + MEMORY_PD_WITHOUT_DATA(bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, + cpu_engine_), + context); diff_dst_data = static_cast(diff_dst.GetOpMem().get_data_handle()); } else { diff_dst_data = static_cast(const_cast(diff_dst_tensor.flat().data())); } + std::shared_ptr bwd_cpu_stream; + bwd_cpu_stream.reset(CreateStream(context, conv_bwd_input->GetEngine())); // Execute conv bwd input primitive. if (!eager_mode) { - conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data); + conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data, + bwd_cpu_stream); } else { // In eager mode we first write the output to temporary // buffer in MKL format. Then we convert the data to TF format. T* tmp_data = static_cast(const_cast(tmp_tensor.flat().data())); - conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data); + conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data, + bwd_cpu_stream); auto output_tf_md = diff_src_mkl_shape.GetTfLayout(); #ifndef ENABLE_MKLDNN_V1 auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_); @@ -563,7 +567,7 @@ class MklConvCustomBackpropInputOp memory* dst_data_mem = new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, diff_src_data); CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem, - cpu_engine_); + cpu_engine_, context); } // Delete primitive since it is not cached.