diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc index 532dbaa79b4..a11e7ebcbf5 100644 --- a/tensorflow/core/kernels/mkl_lrn_op.cc +++ b/tensorflow/core/kernels/mkl_lrn_op.cc @@ -88,7 +88,6 @@ class MklLRNOp : public OpKernel { workspace_enabled_ = false; OP_REQUIRES_OK(context, context->GetAttr("workspace_enabled", &workspace_enabled_)); - fwd_stream_.reset(new CPU_STREAM(cpu_engine_)); } void Compute(OpKernelContext* context) override { @@ -169,6 +168,7 @@ class MklLRNOp : public OpKernel { lrn_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_)); std::vector<primitive> net; + fwd_stream_.reset(CreateStream(context, cpu_engine_)); #ifdef ENABLE_MKLDNN_V1 net.push_back(lrn_forward(lrn_prim_desc)); std::vector<std::unordered_map<int, memory>> net_args; diff --git a/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc b/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc index f7b72f77cb9..0a0464f648b 100644 --- a/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc +++ b/tensorflow/core/kernels/mkl_requantize_per_channel_op.cc @@ -130,9 +130,10 @@ class MklRequantizePerChannelOp : public OpKernel { GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(input_mem_prim), GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(output_mem_prim), cpu_engine_, reorder_attr); - mkldnn::stream reorder_stream = CPU_STREAM(cpu_engine_); + std::shared_ptr<stream> reorder_stream; + reorder_stream.reset(CreateStream(ctx, cpu_engine_)); #ifndef ENABLE_MKLDNN_V1 - reorder_stream.submit( + reorder_stream->submit( {mkldnn::reorder(reorder_pd, *input_mem_prim, *output_mem_prim)}); #else std::unordered_map<int, mkldnn::memory> reorder_args = { @@ -140,7 +141,7 @@ class MklRequantizePerChannelOp : public OpKernel { {MKLDNN_ARG_TO, *output_mem_prim}}; std::unique_ptr<mkldnn::primitive> reorder_prim( new mkldnn::reorder(reorder_pd)); - reorder_prim->execute(reorder_stream, reorder_args); + reorder_prim->execute(*reorder_stream, reorder_args); #endif // !ENABLE_MKLDNN_V1 Tensor* output_min = nullptr; diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc index 699c3d44eb7..02471f4a6f6 100644 --- a/tensorflow/core/kernels/mkl_slice_op.cc +++ b/tensorflow/core/kernels/mkl_slice_op.cc @@ -181,22 +181,21 @@ template <typename T> class MklSlicePrimitive : public MklPrimitive { public: explicit MklSlicePrimitive(const MklSliceParams& sliceParams) - : cpu_engine_(ENGINE_CPU, 0) { - context_.slice_stream.reset(new CPU_STREAM(cpu_engine_)); + : MklPrimitive(engine(ENGINE_CPU, 0)) { Setup(sliceParams); } ~MklSlicePrimitive() {} - void Execute(const MklSliceParams& sliceParams) { + void Execute(const MklSliceParams& sliceParams, std::shared_ptr<stream> slice_stream) { context_.src_mem->set_data_handle(sliceParams.from->get_data_handle()); context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle()); #ifdef ENABLE_MKLDNN_V1 - execute_primitives(context_.slice_primitives, context_.slice_stream, + execute_primitives(context_.slice_primitives, slice_stream, context_.slice_primitives_args); #else - context_.slice_stream->submit(context_.slice_primitives); + slice_stream->submit(context_.slice_primitives); #endif // We should set it back to DummyData so as to make the primitive @@ -228,8 +227,6 @@ class MklSlicePrimitive : public MklPrimitive { : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {} } context_; - engine cpu_engine_; - void Setup(const MklSliceParams& sliceParams) { // Actually, DummyData will not be used in computation, // because the real data will be filled before execution. @@ -465,7 +462,7 @@ class MklSliceOp : public OpKernel { auto op_md = MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides); #ifdef ENABLE_MKLDNN_V1 - src.CheckReorderToOpMem(op_md, cpu_engine); + src.CheckReorderToOpMem(op_md, cpu_engine, context); #else auto op_pd = memory::primitive_desc(op_md, cpu_engine); src.CheckReorderToOpMem(op_pd); @@ -492,7 +489,9 @@ class MklSliceOp : public OpKernel { MklSlicePrimitive<T>* reorder_prim = MklSlicePrimitiveFactory<T>::Get(sliceParams); // Execute slice reorder. - reorder_prim->Execute(sliceParams); + std::shared_ptr<stream> slice_stream; + slice_stream.reset(CreateStream(context, reorder_prim->GetEngine())); + reorder_prim->Execute(sliceParams, slice_stream); } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " +