Threadpool api support for misc ops.

This commit is contained in:
Srinivasan Narayanamoorthy 2020-05-12 13:18:53 -07:00
parent 13ce8851cb
commit dd4585014b
3 changed files with 13 additions and 13 deletions

View File

@ -88,7 +88,6 @@ class MklLRNOp : public OpKernel {
workspace_enabled_ = false;
OP_REQUIRES_OK(context,
context->GetAttr("workspace_enabled", &workspace_enabled_));
fwd_stream_.reset(new CPU_STREAM(cpu_engine_));
}
void Compute(OpKernelContext* context) override {
@ -169,6 +168,7 @@ class MklLRNOp : public OpKernel {
lrn_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
std::vector<primitive> net;
fwd_stream_.reset(CreateStream(context, cpu_engine_));
#ifdef ENABLE_MKLDNN_V1
net.push_back(lrn_forward(lrn_prim_desc));
std::vector<std::unordered_map<int, memory>> net_args;

View File

@ -130,9 +130,10 @@ class MklRequantizePerChannelOp : public OpKernel {
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(input_mem_prim),
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(output_mem_prim),
cpu_engine_, reorder_attr);
mkldnn::stream reorder_stream = CPU_STREAM(cpu_engine_);
std::shared_ptr<stream> reorder_stream;
reorder_stream.reset(CreateStream(ctx, cpu_engine_));
#ifndef ENABLE_MKLDNN_V1
reorder_stream.submit(
reorder_stream->submit(
{mkldnn::reorder(reorder_pd, *input_mem_prim, *output_mem_prim)});
#else
std::unordered_map<int, mkldnn::memory> reorder_args = {
@ -140,7 +141,7 @@ class MklRequantizePerChannelOp : public OpKernel {
{MKLDNN_ARG_TO, *output_mem_prim}};
std::unique_ptr<mkldnn::primitive> reorder_prim(
new mkldnn::reorder(reorder_pd));
reorder_prim->execute(reorder_stream, reorder_args);
reorder_prim->execute(*reorder_stream, reorder_args);
#endif // !ENABLE_MKLDNN_V1
Tensor* output_min = nullptr;

View File

@ -181,22 +181,21 @@ template <typename T>
class MklSlicePrimitive : public MklPrimitive {
public:
explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.slice_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
Setup(sliceParams);
}
~MklSlicePrimitive() {}
void Execute(const MklSliceParams& sliceParams) {
void Execute(const MklSliceParams& sliceParams, std::shared_ptr<stream> slice_stream) {
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.slice_primitives, context_.slice_stream,
execute_primitives(context_.slice_primitives, slice_stream,
context_.slice_primitives_args);
#else
context_.slice_stream->submit(context_.slice_primitives);
slice_stream->submit(context_.slice_primitives);
#endif
// We should set it back to DummyData so as to make the primitive
@ -228,8 +227,6 @@ class MklSlicePrimitive : public MklPrimitive {
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
} context_;
engine cpu_engine_;
void Setup(const MklSliceParams& sliceParams) {
// Actually, DummyData will not be used in computation,
// because the real data will be filled before execution.
@ -465,7 +462,7 @@ class MklSliceOp : public OpKernel {
auto op_md =
MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
#ifdef ENABLE_MKLDNN_V1
src.CheckReorderToOpMem(op_md, cpu_engine);
src.CheckReorderToOpMem(op_md, cpu_engine, context);
#else
auto op_pd = memory::primitive_desc(op_md, cpu_engine);
src.CheckReorderToOpMem(op_pd);
@ -492,7 +489,9 @@ class MklSliceOp : public OpKernel {
MklSlicePrimitive<T>* reorder_prim =
MklSlicePrimitiveFactory<T>::Get(sliceParams);
// Execute slice reorder.
reorder_prim->Execute(sliceParams);
std::shared_ptr<stream> slice_stream;
slice_stream.reset(CreateStream(context, reorder_prim->GetEngine()));
reorder_prim->Execute(sliceParams, slice_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +