[Intel MKL] Adding DNNL ops (part 2) supporting threadpool work
This commit is contained in:
parent
7c694e10b5
commit
5016da3128
@ -178,6 +178,9 @@ class MklAddNOp : public OpKernel {
|
|||||||
dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
|
dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_data_format);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<stream> fwd_cpu_stream;
|
||||||
|
fwd_cpu_stream.reset(CreateStream(ctx, cpu_engine));
|
||||||
|
|
||||||
// Create memory descriptor for MKL-DNN.
|
// Create memory descriptor for MKL-DNN.
|
||||||
// If all input in Tensorflow format, create block memory descriptor,
|
// If all input in Tensorflow format, create block memory descriptor,
|
||||||
// else convert TF format to MKL memory descriptor
|
// else convert TF format to MKL memory descriptor
|
||||||
@ -215,6 +218,7 @@ class MklAddNOp : public OpKernel {
|
|||||||
srcs_pd.push_back(memory::primitive_desc(md, cpu_engine));
|
srcs_pd.push_back(memory::primitive_desc(md, cpu_engine));
|
||||||
#endif
|
#endif
|
||||||
src.SetUsrMem(md, &src_tensor);
|
src.SetUsrMem(md, &src_tensor);
|
||||||
|
src.SetUsrMemDataHandle(&src_tensor, fwd_cpu_stream);
|
||||||
inputs.push_back(src.GetOpMem());
|
inputs.push_back(src.GetOpMem());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -240,11 +244,10 @@ class MklAddNOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
|
AllocateOutputSetMklShape(ctx, kOutputIdx, &dst_tensor, output_tf_shape,
|
||||||
output_mkl_shape);
|
output_mkl_shape);
|
||||||
dst.SetUsrMemDataHandle(dst_tensor);
|
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
|
||||||
|
|
||||||
// Create Sum op, and submit net for execution.
|
// Create Sum op, and submit net for execution.
|
||||||
std::vector<primitive> net;
|
std::vector<primitive> net;
|
||||||
stream* fwd_cpu_stream = CreateStream(ctx, cpu_engine);
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
mkldnn::sum sum_op(sum_pd);
|
mkldnn::sum sum_op(sum_pd);
|
||||||
std::unordered_map<int, memory> net_args = {
|
std::unordered_map<int, memory> net_args = {
|
||||||
|
@ -281,11 +281,19 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
|||||||
std::shared_ptr<stream> fwd_stream) {
|
std::shared_ptr<stream> fwd_stream) {
|
||||||
DCHECK_EQ(in_data.size(), context_.data_mem.size());
|
DCHECK_EQ(in_data.size(), context_.data_mem.size());
|
||||||
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
||||||
|
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||||
|
context_.data_mem_shdptr[i]->set_data_handle(
|
||||||
|
static_cast<void*>(in_data[i].get_data_handle()), *fwd_stream);
|
||||||
|
}
|
||||||
|
context_.dst_mem->set_data_handle(
|
||||||
|
static_cast<void*>(dst_data.get_data_handle()), *fwd_stream);
|
||||||
|
#else
|
||||||
context_.data_mem_shdptr[i]->set_data_handle(
|
context_.data_mem_shdptr[i]->set_data_handle(
|
||||||
static_cast<void*>(in_data[i].get_data_handle()));
|
static_cast<void*>(in_data[i].get_data_handle()));
|
||||||
}
|
}
|
||||||
context_.dst_mem->set_data_handle(
|
context_.dst_mem->set_data_handle(
|
||||||
static_cast<void*>(dst_data.get_data_handle()));
|
static_cast<void*>(dst_data.get_data_handle()));
|
||||||
|
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||||
|
|
||||||
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
||||||
context_.data_mem[i] = *context_.data_mem_shdptr[i];
|
context_.data_mem[i] = *context_.data_mem_shdptr[i];
|
||||||
@ -788,11 +796,13 @@ class MklConcatOp : public OpKernel {
|
|||||||
dnn_shape_dst);
|
dnn_shape_dst);
|
||||||
DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
|
DCHECK(dst_tensor != nullptr) << "Output tensor pointer is NULL";
|
||||||
|
|
||||||
|
std::shared_ptr<stream> fwd_cpu_stream;
|
||||||
|
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
|
||||||
|
|
||||||
if (dnn_shape_dst.IsMklTensor())
|
if (dnn_shape_dst.IsMklTensor())
|
||||||
dst_md = dnn_shape_dst.GetMklLayout();
|
dst_md = dnn_shape_dst.GetMklLayout();
|
||||||
dst.SetUsrMem(dst_md, dst_tensor);
|
dst.SetUsrMem(dst_md, dst_tensor);
|
||||||
std::shared_ptr<stream> fwd_cpu_stream;
|
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
|
||||||
fwd_cpu_stream.reset(CreateStream(context, cpu_engine));
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
auto concat_op = concat(concat_pd);
|
auto concat_op = concat(concat_pd);
|
||||||
std::unordered_map<int, memory> net_args = {
|
std::unordered_map<int, memory> net_args = {
|
||||||
@ -830,9 +840,10 @@ class MklConcatOp : public OpKernel {
|
|||||||
|
|
||||||
dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
|
dst_md = dnn_shape_dst.IsMklTensor() ? dnn_shape_dst.GetMklLayout()
|
||||||
: dst_md;
|
: dst_md;
|
||||||
dst.SetUsrMem(dst_md, dst_tensor);
|
|
||||||
std::shared_ptr<stream> fwd_cpu_stream;
|
std::shared_ptr<stream> fwd_cpu_stream;
|
||||||
fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine()));
|
fwd_cpu_stream.reset(CreateStream(context, concat_fwd->GetEngine()));
|
||||||
|
dst.SetUsrMem(dst_md, dst_tensor);
|
||||||
|
dst.SetUsrMemDataHandle(dst_tensor, fwd_cpu_stream);
|
||||||
// Execute concat
|
// Execute concat
|
||||||
concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims,
|
concat_fwd->Execute(srcs_mem, dst.GetOpMem(), concat_fwd_dims,
|
||||||
fwd_cpu_stream);
|
fwd_cpu_stream);
|
||||||
|
@ -75,6 +75,9 @@ class MklDequantizeOp : public OpKernel {
|
|||||||
MklDnnData<T> src(&cpu_engine);
|
MklDnnData<T> src(&cpu_engine);
|
||||||
MklDnnData<float> dst(&cpu_engine);
|
MklDnnData<float> dst(&cpu_engine);
|
||||||
|
|
||||||
|
std::shared_ptr<stream> reorder_stream;
|
||||||
|
reorder_stream.reset(CreateStream(ctx, cpu_engine));
|
||||||
|
|
||||||
// If input is in MKL layout, then simply grab input layout; otherwise,
|
// If input is in MKL layout, then simply grab input layout; otherwise,
|
||||||
// construct input TF layout. For TF layout, although input shape
|
// construct input TF layout. For TF layout, although input shape
|
||||||
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
|
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
|
||||||
@ -85,6 +88,7 @@ class MklDequantizeOp : public OpKernel {
|
|||||||
: memory::desc(src_dims, MklDnnType<T>(), MEMORY_FORMAT::nhwc);
|
: memory::desc(src_dims, MklDnnType<T>(), MEMORY_FORMAT::nhwc);
|
||||||
|
|
||||||
src.SetUsrMem(src_md, &src_tensor);
|
src.SetUsrMem(src_md, &src_tensor);
|
||||||
|
src.SetUsrMemDataHandle(&src_tensor, reorder_stream);
|
||||||
|
|
||||||
Tensor* output_tensor = nullptr;
|
Tensor* output_tensor = nullptr;
|
||||||
MklDnnShape output_mkl_shape;
|
MklDnnShape output_mkl_shape;
|
||||||
@ -129,6 +133,7 @@ class MklDequantizeOp : public OpKernel {
|
|||||||
AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape,
|
AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape,
|
||||||
output_mkl_shape);
|
output_mkl_shape);
|
||||||
dst.SetUsrMem(dst_md, output_tensor);
|
dst.SetUsrMem(dst_md, output_tensor);
|
||||||
|
dst.SetUsrMemDataHandle(output_tensor, reorder_stream);
|
||||||
|
|
||||||
// The quantization logic here for mode SCALED is similar to the logic
|
// The quantization logic here for mode SCALED is similar to the logic
|
||||||
// in QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
|
// in QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
|
||||||
@ -155,8 +160,6 @@ class MklDequantizeOp : public OpKernel {
|
|||||||
// Also it does not define round_nearest (enum).
|
// Also it does not define round_nearest (enum).
|
||||||
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
|
attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest);
|
||||||
#endif // !ENABLE_MKLDNN_V1
|
#endif // !ENABLE_MKLDNN_V1
|
||||||
std::shared_ptr<stream> reorder_stream;
|
|
||||||
reorder_stream.reset(CreateStream(ctx, cpu_engine));
|
|
||||||
std::vector<primitive> net;
|
std::vector<primitive> net;
|
||||||
|
|
||||||
// Create reorder primitive and then execute.
|
// Create reorder primitive and then execute.
|
||||||
|
@ -137,6 +137,7 @@ class MklLRNOp : public OpKernel {
|
|||||||
// that input is in NHWC layout with Channel being the last dimension.
|
// that input is in NHWC layout with Channel being the last dimension.
|
||||||
src_dnn_data.SetUsrMem(src_md, &src_tensor);
|
src_dnn_data.SetUsrMem(src_md, &src_tensor);
|
||||||
src_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
|
src_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
|
||||||
|
src_dnn_data.SetUsrMemDataHandle(&src_tensor, fwd_stream_);
|
||||||
|
|
||||||
// dst_dnn_data has the same shape as input.
|
// dst_dnn_data has the same shape as input.
|
||||||
dst_dnn_data.SetUsrMem(src_md);
|
dst_dnn_data.SetUsrMem(src_md);
|
||||||
@ -157,7 +158,7 @@ class MklLRNOp : public OpKernel {
|
|||||||
&output_tensor);
|
&output_tensor);
|
||||||
OP_REQUIRES_OK(context, context->status());
|
OP_REQUIRES_OK(context, context->status());
|
||||||
DCHECK(output_tensor != nullptr);
|
DCHECK(output_tensor != nullptr);
|
||||||
dst_dnn_data.SetUsrMemDataHandle(output_tensor);
|
dst_dnn_data.SetUsrMemDataHandle(output_tensor, fwd_stream_);
|
||||||
|
|
||||||
// Handle workspace required for MKL-DNN.
|
// Handle workspace required for MKL-DNN.
|
||||||
AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
|
AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
|
||||||
@ -393,6 +394,7 @@ class MklLRNGradOp : public OpKernel {
|
|||||||
orig_input_dnn_shape.GetSizesAsMklDnnDims();
|
orig_input_dnn_shape.GetSizesAsMklDnnDims();
|
||||||
orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
|
orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
|
||||||
orig_input_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
|
orig_input_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
|
||||||
|
orig_input_dnn_data.SetUsrMemDataHandle(&orig_input_tensor, bwd_stream_);
|
||||||
|
|
||||||
// output_dnn_data has the same shape as original input
|
// output_dnn_data has the same shape as original input
|
||||||
output_dnn_data.SetUsrMem(orig_input_md);
|
output_dnn_data.SetUsrMem(orig_input_md);
|
||||||
@ -421,7 +423,7 @@ class MklLRNGradOp : public OpKernel {
|
|||||||
orig_input_format, &output_tensor);
|
orig_input_format, &output_tensor);
|
||||||
OP_REQUIRES_OK(context, context->status());
|
OP_REQUIRES_OK(context, context->status());
|
||||||
DCHECK(output_tensor != nullptr);
|
DCHECK(output_tensor != nullptr);
|
||||||
output_dnn_data.SetUsrMemDataHandle(output_tensor);
|
output_dnn_data.SetUsrMemDataHandle(output_tensor, bwd_stream_);
|
||||||
|
|
||||||
// Create LRN primitive and add it to the net
|
// Create LRN primitive and add it to the net
|
||||||
// At this point, workspace is enabled, so we don't need
|
// At this point, workspace is enabled, so we don't need
|
||||||
|
@ -137,6 +137,7 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
|
|||||||
memory::dims out_strides =
|
memory::dims out_strides =
|
||||||
ReorderStrides(CalculateTFStrides(out_dims), perm);
|
ReorderStrides(CalculateTFStrides(out_dims), perm);
|
||||||
|
|
||||||
|
std::shared_ptr<stream> transpose_stream;
|
||||||
in.SetUsrMem(in_dims, in_strides, &in_tensor);
|
in.SetUsrMem(in_dims, in_strides, &in_tensor);
|
||||||
// Output dimensions are same as input dimensions. We adjust the layout
|
// Output dimensions are same as input dimensions. We adjust the layout
|
||||||
// using strides.
|
// using strides.
|
||||||
@ -144,16 +145,16 @@ Status MKLTransposeND(OpKernelContext* context, const Tensor& in_tensor,
|
|||||||
|
|
||||||
std::vector<primitive> net;
|
std::vector<primitive> net;
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
std::shared_ptr<stream> transpose_stream;
|
|
||||||
auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem());
|
auto* prim = FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem());
|
||||||
transpose_stream.reset(CreateStream(context, prim->GetEngine()));
|
transpose_stream.reset(CreateStream(context, prim->GetEngine()));
|
||||||
|
in.SetUsrMemDataHandle(&in_tensor, transpose_stream);
|
||||||
|
out.SetUsrMemDataHandle(out_tensor, transpose_stream);
|
||||||
net.push_back(*(prim->GetPrimitive()));
|
net.push_back(*(prim->GetPrimitive()));
|
||||||
std::vector<MemoryArgsMap> net_args;
|
std::vector<MemoryArgsMap> net_args;
|
||||||
net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()},
|
net_args.push_back({{MKLDNN_ARG_FROM, *in.GetUsrMem()},
|
||||||
{MKLDNN_ARG_TO, *out.GetUsrMem()}});
|
{MKLDNN_ARG_TO, *out.GetUsrMem()}});
|
||||||
execute_primitives(net, transpose_stream, net_args);
|
execute_primitives(net, transpose_stream, net_args);
|
||||||
#else
|
#else
|
||||||
std::shared_ptr<stream> transpose_stream;
|
|
||||||
transpose_stream.reset(new CPU_STREAM(cpu_engine));
|
transpose_stream.reset(new CPU_STREAM(cpu_engine));
|
||||||
net.push_back(FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem()));
|
net.push_back(FindOrCreateReorder<T>(in.GetUsrMem(), out.GetUsrMem()));
|
||||||
transpose_stream->submit(net).wait();
|
transpose_stream->submit(net).wait();
|
||||||
|
@ -1524,17 +1524,27 @@ class MklDnnData {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Set function for data buffer of user memory primitive.
|
/// Set function for data buffer of user memory primitive.
|
||||||
inline void SetUsrMemDataHandle(void* data_buffer) {
|
inline void SetUsrMemDataHandle(void* data_buffer,
|
||||||
|
std::shared_ptr<stream> t_stream = nullptr) {
|
||||||
CHECK_NOTNULL(user_memory_);
|
CHECK_NOTNULL(user_memory_);
|
||||||
CHECK_NOTNULL(data_buffer);
|
CHECK_NOTNULL(data_buffer);
|
||||||
|
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||||
|
user_memory_->set_data_handle(data_buffer, *t_stream);
|
||||||
|
#else
|
||||||
user_memory_->set_data_handle(data_buffer);
|
user_memory_->set_data_handle(data_buffer);
|
||||||
|
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set function for data buffer of user memory primitive.
|
/// Set function for data buffer of user memory primitive.
|
||||||
inline void SetUsrMemDataHandle(const Tensor* tensor) {
|
inline void SetUsrMemDataHandle(const Tensor* tensor,
|
||||||
|
std::shared_ptr<stream> t_stream = nullptr) {
|
||||||
CHECK_NOTNULL(user_memory_);
|
CHECK_NOTNULL(user_memory_);
|
||||||
CHECK_NOTNULL(tensor);
|
CHECK_NOTNULL(tensor);
|
||||||
|
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||||
|
user_memory_->set_data_handle(GetTensorBuffer(tensor), *t_stream);
|
||||||
|
#else
|
||||||
user_memory_->set_data_handle(GetTensorBuffer(tensor));
|
user_memory_->set_data_handle(GetTensorBuffer(tensor));
|
||||||
|
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||||
}
|
}
|
||||||
|
|
||||||
/// allocate function for data buffer
|
/// allocate function for data buffer
|
||||||
|
Loading…
x
Reference in New Issue
Block a user