Threadpool api support for misc ops.

This commit is contained in:
Srinivasan Narayanamoorthy 2020-05-12 13:18:53 -07:00
parent 13ce8851cb
commit dd4585014b
3 changed files with 13 additions and 13 deletions

View File

@ -88,7 +88,6 @@ class MklLRNOp : public OpKernel {
workspace_enabled_ = false; workspace_enabled_ = false;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
context->GetAttr("workspace_enabled", &workspace_enabled_)); context->GetAttr("workspace_enabled", &workspace_enabled_));
fwd_stream_.reset(new CPU_STREAM(cpu_engine_));
} }
void Compute(OpKernelContext* context) override { void Compute(OpKernelContext* context) override {
@ -169,6 +168,7 @@ class MklLRNOp : public OpKernel {
lrn_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_)); lrn_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
std::vector<primitive> net; std::vector<primitive> net;
fwd_stream_.reset(CreateStream(context, cpu_engine_));
#ifdef ENABLE_MKLDNN_V1 #ifdef ENABLE_MKLDNN_V1
net.push_back(lrn_forward(lrn_prim_desc)); net.push_back(lrn_forward(lrn_prim_desc));
std::vector<std::unordered_map<int, memory>> net_args; std::vector<std::unordered_map<int, memory>> net_args;

View File

@ -130,9 +130,10 @@ class MklRequantizePerChannelOp : public OpKernel {
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(input_mem_prim), GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(input_mem_prim),
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(output_mem_prim), GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(output_mem_prim),
cpu_engine_, reorder_attr); cpu_engine_, reorder_attr);
mkldnn::stream reorder_stream = CPU_STREAM(cpu_engine_); std::shared_ptr<stream> reorder_stream;
reorder_stream.reset(CreateStream(ctx, cpu_engine_));
#ifndef ENABLE_MKLDNN_V1 #ifndef ENABLE_MKLDNN_V1
reorder_stream.submit( reorder_stream->submit(
{mkldnn::reorder(reorder_pd, *input_mem_prim, *output_mem_prim)}); {mkldnn::reorder(reorder_pd, *input_mem_prim, *output_mem_prim)});
#else #else
std::unordered_map<int, mkldnn::memory> reorder_args = { std::unordered_map<int, mkldnn::memory> reorder_args = {
@ -140,7 +141,7 @@ class MklRequantizePerChannelOp : public OpKernel {
{MKLDNN_ARG_TO, *output_mem_prim}}; {MKLDNN_ARG_TO, *output_mem_prim}};
std::unique_ptr<mkldnn::primitive> reorder_prim( std::unique_ptr<mkldnn::primitive> reorder_prim(
new mkldnn::reorder(reorder_pd)); new mkldnn::reorder(reorder_pd));
reorder_prim->execute(reorder_stream, reorder_args); reorder_prim->execute(*reorder_stream, reorder_args);
#endif // !ENABLE_MKLDNN_V1 #endif // !ENABLE_MKLDNN_V1
Tensor* output_min = nullptr; Tensor* output_min = nullptr;

View File

@ -181,22 +181,21 @@ template <typename T>
class MklSlicePrimitive : public MklPrimitive { class MklSlicePrimitive : public MklPrimitive {
public: public:
explicit MklSlicePrimitive(const MklSliceParams& sliceParams) explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
: cpu_engine_(ENGINE_CPU, 0) { : MklPrimitive(engine(ENGINE_CPU, 0)) {
context_.slice_stream.reset(new CPU_STREAM(cpu_engine_));
Setup(sliceParams); Setup(sliceParams);
} }
~MklSlicePrimitive() {} ~MklSlicePrimitive() {}
void Execute(const MklSliceParams& sliceParams) { void Execute(const MklSliceParams& sliceParams, std::shared_ptr<stream> slice_stream) {
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle()); context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle()); context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
#ifdef ENABLE_MKLDNN_V1 #ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.slice_primitives, context_.slice_stream, execute_primitives(context_.slice_primitives, slice_stream,
context_.slice_primitives_args); context_.slice_primitives_args);
#else #else
context_.slice_stream->submit(context_.slice_primitives); slice_stream->submit(context_.slice_primitives);
#endif #endif
// We should set it back to DummyData so as to make the primitive // We should set it back to DummyData so as to make the primitive
@ -228,8 +227,6 @@ class MklSlicePrimitive : public MklPrimitive {
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {} : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
} context_; } context_;
engine cpu_engine_;
void Setup(const MklSliceParams& sliceParams) { void Setup(const MklSliceParams& sliceParams) {
// Actually, DummyData will not be used in computation, // Actually, DummyData will not be used in computation,
// because the real data will be filled before execution. // because the real data will be filled before execution.
@ -465,7 +462,7 @@ class MklSliceOp : public OpKernel {
auto op_md = auto op_md =
MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides); MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
#ifdef ENABLE_MKLDNN_V1 #ifdef ENABLE_MKLDNN_V1
src.CheckReorderToOpMem(op_md, cpu_engine); src.CheckReorderToOpMem(op_md, cpu_engine, context);
#else #else
auto op_pd = memory::primitive_desc(op_md, cpu_engine); auto op_pd = memory::primitive_desc(op_md, cpu_engine);
src.CheckReorderToOpMem(op_pd); src.CheckReorderToOpMem(op_pd);
@ -492,7 +489,9 @@ class MklSliceOp : public OpKernel {
MklSlicePrimitive<T>* reorder_prim = MklSlicePrimitive<T>* reorder_prim =
MklSlicePrimitiveFactory<T>::Get(sliceParams); MklSlicePrimitiveFactory<T>::Get(sliceParams);
// Execute slice reorder. // Execute slice reorder.
reorder_prim->Execute(sliceParams); std::shared_ptr<stream> slice_stream;
slice_stream.reset(CreateStream(context, reorder_prim->GetEngine()));
reorder_prim->Execute(sliceParams, slice_stream);
} catch (mkldnn::error& e) { } catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) + string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " + ", message: " + string(e.message) + ", in file " +