Threadpool api support for misc ops.
This commit is contained in:
parent
13ce8851cb
commit
dd4585014b
tensorflow/core/kernels
@ -88,7 +88,6 @@ class MklLRNOp : public OpKernel {
|
||||
workspace_enabled_ = false;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("workspace_enabled", &workspace_enabled_));
|
||||
fwd_stream_.reset(new CPU_STREAM(cpu_engine_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
@ -169,6 +168,7 @@ class MklLRNOp : public OpKernel {
|
||||
lrn_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
|
||||
|
||||
std::vector<primitive> net;
|
||||
fwd_stream_.reset(CreateStream(context, cpu_engine_));
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
net.push_back(lrn_forward(lrn_prim_desc));
|
||||
std::vector<std::unordered_map<int, memory>> net_args;
|
||||
|
@ -130,9 +130,10 @@ class MklRequantizePerChannelOp : public OpKernel {
|
||||
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(input_mem_prim),
|
||||
GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(output_mem_prim),
|
||||
cpu_engine_, reorder_attr);
|
||||
mkldnn::stream reorder_stream = CPU_STREAM(cpu_engine_);
|
||||
std::shared_ptr<stream> reorder_stream;
|
||||
reorder_stream.reset(CreateStream(ctx, cpu_engine_));
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
reorder_stream.submit(
|
||||
reorder_stream->submit(
|
||||
{mkldnn::reorder(reorder_pd, *input_mem_prim, *output_mem_prim)});
|
||||
#else
|
||||
std::unordered_map<int, mkldnn::memory> reorder_args = {
|
||||
@ -140,7 +141,7 @@ class MklRequantizePerChannelOp : public OpKernel {
|
||||
{MKLDNN_ARG_TO, *output_mem_prim}};
|
||||
std::unique_ptr<mkldnn::primitive> reorder_prim(
|
||||
new mkldnn::reorder(reorder_pd));
|
||||
reorder_prim->execute(reorder_stream, reorder_args);
|
||||
reorder_prim->execute(*reorder_stream, reorder_args);
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
|
||||
Tensor* output_min = nullptr;
|
||||
|
@ -181,22 +181,21 @@ template <typename T>
|
||||
class MklSlicePrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
|
||||
: cpu_engine_(ENGINE_CPU, 0) {
|
||||
context_.slice_stream.reset(new CPU_STREAM(cpu_engine_));
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
Setup(sliceParams);
|
||||
}
|
||||
|
||||
~MklSlicePrimitive() {}
|
||||
|
||||
void Execute(const MklSliceParams& sliceParams) {
|
||||
void Execute(const MklSliceParams& sliceParams, std::shared_ptr<stream> slice_stream) {
|
||||
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
|
||||
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.slice_primitives, context_.slice_stream,
|
||||
execute_primitives(context_.slice_primitives, slice_stream,
|
||||
context_.slice_primitives_args);
|
||||
#else
|
||||
context_.slice_stream->submit(context_.slice_primitives);
|
||||
slice_stream->submit(context_.slice_primitives);
|
||||
#endif
|
||||
|
||||
// We should set it back to DummyData so as to make the primitive
|
||||
@ -228,8 +227,6 @@ class MklSlicePrimitive : public MklPrimitive {
|
||||
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
|
||||
} context_;
|
||||
|
||||
engine cpu_engine_;
|
||||
|
||||
void Setup(const MklSliceParams& sliceParams) {
|
||||
// Actually, DummyData will not be used in computation,
|
||||
// because the real data will be filled before execution.
|
||||
@ -465,7 +462,7 @@ class MklSliceOp : public OpKernel {
|
||||
auto op_md =
|
||||
MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
src.CheckReorderToOpMem(op_md, cpu_engine);
|
||||
src.CheckReorderToOpMem(op_md, cpu_engine, context);
|
||||
#else
|
||||
auto op_pd = memory::primitive_desc(op_md, cpu_engine);
|
||||
src.CheckReorderToOpMem(op_pd);
|
||||
@ -492,7 +489,9 @@ class MklSliceOp : public OpKernel {
|
||||
MklSlicePrimitive<T>* reorder_prim =
|
||||
MklSlicePrimitiveFactory<T>::Get(sliceParams);
|
||||
// Execute slice reorder.
|
||||
reorder_prim->Execute(sliceParams);
|
||||
std::shared_ptr<stream> slice_stream;
|
||||
slice_stream.reset(CreateStream(context, reorder_prim->GetEngine()));
|
||||
reorder_prim->Execute(sliceParams, slice_stream);
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = "Status: " + std::to_string(e.status) +
|
||||
", message: " + string(e.message) + ", in file " +
|
||||
|
Loading…
Reference in New Issue
Block a user