From 16a10ea5f97ed2c7e0a7132380e355a35a4b9afc Mon Sep 17 00:00:00 2001 From: mdfaijul Date: Sun, 9 Feb 2020 22:41:19 -0800 Subject: [PATCH] Added support for MKLDNN 1.x for QuantizeOpV2 and DequantizeOp. --- tensorflow/core/kernels/mkl_dequantize_op.cc | 54 +++-- tensorflow/core/kernels/mkl_quantize_op.cc | 203 ++++++++----------- tensorflow/core/util/mkl_types.h | 3 + tensorflow/core/util/mkl_util.h | 8 +- 4 files changed, 122 insertions(+), 146 deletions(-) diff --git a/tensorflow/core/kernels/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl_dequantize_op.cc index 4c9dbf4274a..2e046bf85bb 100644 --- a/tensorflow/core/kernels/mkl_dequantize_op.cc +++ b/tensorflow/core/kernels/mkl_dequantize_op.cc @@ -17,18 +17,18 @@ limitations under the License. #define EIGEN_USE_THREADS +#include "mkldnn.hpp" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/kernels/meta_support.h" #include "tensorflow/core/kernels/quantization_utils.h" #include "tensorflow/core/lib/core/errors.h" - -#include "tensorflow/core/graph/mkl_graph_util.h" +#include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/mkl_util.h" -#include "mkldnn.hpp" using mkldnn::primitive_attr; using mkldnn::stream; @@ -51,7 +51,7 @@ class MklDequantizeOp : public OpKernel { void Compute(OpKernelContext* ctx) override { try { // Using CPU device - auto cpu_engine = engine(engine::cpu, 0); + auto cpu_engine = engine(ENGINE_CPU, 0); // Get the inputs const Tensor& src_tensor = MklGetInput(ctx, kSrcIndex); @@ -82,33 +82,28 @@ class MklDequantizeOp : public OpKernel { auto src_md = src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetMklLayout() - : memory::desc(src_dims, MklDnnType(), memory::format::nhwc); + : memory::desc(src_dims, MklDnnType(), MEMORY_FORMAT::nhwc); src.SetUsrMem(src_md, &src_tensor); Tensor* output_tensor = nullptr; MklDnnShape output_mkl_shape; TensorShape output_tf_shape; - - memory::primitive_desc src_pd = - memory::primitive_desc(src_md, cpu_engine); memory::desc dst_md = src_mkl_shape.IsMklTensor() ? src_md : memory::desc(src_dims, MklDnnType(), - memory::format::nhwc); - memory::primitive_desc dst_pd = - memory::primitive_desc(dst_md, cpu_engine); - + MEMORY_FORMAT::nhwc); // If input is MKL shape, output is also MKL shape. // If input is TF shape, output is also TF shape. if (src_mkl_shape.IsMklTensor()) { output_mkl_shape.SetMklTensor(true); - output_mkl_shape.SetMklLayout(&dst_pd); + output_mkl_shape.SetMklLayout(&dst_md); output_mkl_shape.SetElemType(MklDnnType()); output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(), src_mkl_shape.GetSizesAsMklDnnDims(), src_mkl_shape.GetTfDataFormat()); - output_tf_shape.AddDim((dst_pd.get_size() / sizeof(float))); + output_tf_shape.AddDim(GET_MEMORY_SIZE_FROM_MD(dst_md, cpu_engine) / + sizeof(float)); } else { output_mkl_shape.SetMklTensor(false); output_tf_shape = MklDnnDimsToTFShape(output_dims); @@ -135,20 +130,35 @@ class MklDequantizeOp : public OpKernel { const float target_range = static_cast((uint64_t{1} << target_bits) - 1); const float scale_factor = max_abs / target_range; - std::vector scales; scales.push_back(scale_factor); primitive_attr attr; attr.set_output_scales(0, scales); +#ifndef ENABLE_MKLDNN_V1 + // MKL-DNN 1.0 does not provide set_int_output_round_mode() API. + // Also it does not define round_nearest (enum). attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest); - mkldnn::reorder::primitive_desc reorder_pd = - mkldnn::reorder::primitive_desc(src_pd, dst_pd, attr); - - // Execute MKL-DNN primitive +#endif // !ENABLE_MKLDNN_V1 + stream reorder_stream = CPU_STREAM(cpu_engine); std::vector net; - net.push_back( - mkldnn::reorder(reorder_pd, *src.GetUsrMem(), *dst.GetUsrMem())); - stream(stream::kind::eager).submit(net).wait(); + + // Create reorder primitive and then execute. + auto reorder_pd = REORDER_PD_CONSTRUCTOR_WITH_ATTR( + GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(src.GetUsrMem()), + GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(dst.GetUsrMem()), cpu_engine, + attr); +#ifdef ENABLE_MKLDNN_V1 + net.push_back(reorder(reorder_pd)); + std::vector> reorder_net_args; + reorder_net_args.push_back({{MKLDNN_ARG_FROM, *src.GetUsrMem()}, + { MKLDNN_ARG_TO, + *dst.GetUsrMem() }}); + execute_primitives(net, std::make_shared(reorder_stream), + reorder_net_args); +#else + net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *dst.GetUsrMem())); + reorder_stream.submit(net); +#endif // ENABLE_MKLDNN_V1 } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + diff --git a/tensorflow/core/kernels/mkl_quantize_op.cc b/tensorflow/core/kernels/mkl_quantize_op.cc index 985f1cd8c88..d049b5f58d2 100644 --- a/tensorflow/core/kernels/mkl_quantize_op.cc +++ b/tensorflow/core/kernels/mkl_quantize_op.cc @@ -17,9 +17,7 @@ limitations under the License. #define EIGEN_USE_THREADS -#include "mkldnn.h" #include "mkldnn.hpp" -#include "mkldnn_types.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/type_traits.h" @@ -27,6 +25,7 @@ limitations under the License. #include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/mkl_util.h" using mkldnn::primitive_attr; @@ -56,7 +55,6 @@ enum { } // namespace namespace tensorflow { - typedef Eigen::ThreadPoolDevice CPUDevice; struct MklReorderWithScaleFwdParams { @@ -78,20 +76,28 @@ struct MklReorderWithScaleFwdParams { class MklReorderWithScalePrimitive : public MklPrimitive { public: explicit MklReorderWithScalePrimitive( - const memory* from, const memory* to, - const MklReorderWithScaleFwdParams& fwdParams) { + const MklReorderWithScaleFwdParams& fwdParams) + : cpu_engine_(ENGINE_CPU, 0) { // Create reorder primitive - Setup(from, to, fwdParams); + Setup(fwdParams); } ~MklReorderWithScalePrimitive() {} std::shared_ptr GetPrimitive() { return context_.reorder_prim; } - // set data handles - void SetMemory(const memory* from, const memory* to) { - context_.src_mem->set_data_handle(from->get_data_handle()); - context_.dst_mem->set_data_handle(to->get_data_handle()); + void Execute(void* src_data, void* dst_data) { + context_.src_mem->set_data_handle(src_data); + context_.dst_mem->set_data_handle(dst_data); +#ifndef ENABLE_MKLDNN_V1 + context_.reorder_stream->submit(context_.net); +#else + context_.reorder_prim->execute(*context_.reorder_stream, + context_.prim_args); +#endif // !ENABLE_MKLDNN_V1 + // After execution, set data handle back. + context_.src_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); } private: @@ -101,41 +107,36 @@ class MklReorderWithScalePrimitive : public MklPrimitive { std::shared_ptr src_mem; std::shared_ptr dst_mem; - // Memory desc - std::shared_ptr src_md; - std::shared_ptr dst_md; - - // Memory primitive desc - std::shared_ptr src_mpd; - std::shared_ptr dst_mpd; - // Reorder primitive descriptor and primitive std::shared_ptr reorder_pd; std::shared_ptr reorder_prim; + // Stream and primitive vector + std::shared_ptr reorder_stream; + +#ifndef ENABLE_MKLDNN_V1 + std::vector net; +#else + std::unordered_map prim_args; +#endif // !ENABLE_MKLDNN_V1 + ReorderContext() : src_mem(nullptr), dst_mem(nullptr), - src_md(nullptr), - dst_md(nullptr), - src_mpd(nullptr), - dst_mpd(nullptr), reorder_pd(nullptr), - reorder_prim(nullptr) {} + reorder_prim(nullptr), + reorder_stream(nullptr) {} } context_; - engine cpu_engine_ = engine(engine::cpu, 0); + engine cpu_engine_; // Reorder primitive setup - void Setup(const memory* from, const memory* to, - const MklReorderWithScaleFwdParams& fwdParams) { + void Setup(const MklReorderWithScaleFwdParams& fwdParams) { // Create memory descriptors for reorder data with specified format - context_.src_md.reset(new memory::desc(fwdParams.src_md.data)); - context_.dst_md.reset(new memory::desc(fwdParams.dst_md.data)); - context_.src_mpd.reset( - new memory::primitive_desc(*context_.src_md, cpu_engine_)); - context_.dst_mpd.reset( - new memory::primitive_desc(*context_.dst_md, cpu_engine_)); + context_.src_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD( + fwdParams.src_md, cpu_engine_, DummyData)); + context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD( + fwdParams.dst_md, cpu_engine_, DummyData)); // Check if there is any fusion as post-ops auto const& post_op_params = fwdParams.post_op_params; @@ -147,18 +148,22 @@ class MklReorderWithScalePrimitive : public MklPrimitive { scales.push_back(post_op_params.param[0]); post_ops_attr.set_output_scales(0, scales); - // Create a reorder - context_.reorder_pd = - std::make_shared(reorder::primitive_desc( - *context_.src_mpd, *context_.dst_mpd, post_ops_attr)); + context_.reorder_pd.reset(new REORDER_PD_CONSTRUCTOR_WITH_ATTR( + GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(context_.src_mem), + GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(context_.dst_mem), cpu_engine_, + post_ops_attr)); - // Create memory primitive based on dummy data - context_.src_mem.reset(new memory(*context_.src_mpd, DummyData)); - context_.dst_mem.reset(new memory(*context_.dst_mpd, DummyData)); - - // Create reorder primitive - context_.reorder_prim = std::make_shared( - reorder(*context_.reorder_pd, *context_.src_mem, *context_.dst_mem)); +// Create reorder primitive +#ifndef ENABLE_MKLDNN_V1 + context_.reorder_prim.reset(new reorder( + *context_.reorder_pd, *context_.src_mem, *context_.dst_mem)); + context_.net.push_back(*context_.reorder_prim); +#else + context_.reorder_prim.reset(new reorder(*context_.reorder_pd)); + context_.prim_args.insert({MKLDNN_ARG_FROM, *context_.src_mem}); + context_.prim_args.insert({MKLDNN_ARG_TO, *context_.dst_mem}); +#endif // !ENABLE_MKLDNN_V1 + context_.reorder_stream.reset(new CPU_STREAM(cpu_engine_)); } }; @@ -173,11 +178,10 @@ class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory { MklReorderWithScalePrimitiveFactory::GetInstance().GetReorder( from, to, fwdParams)); if (reorderPrim == nullptr) { - reorderPrim = new MklReorderWithScalePrimitive(from, to, fwdParams); + reorderPrim = new MklReorderWithScalePrimitive(fwdParams); MklReorderWithScalePrimitiveFactory::GetInstance().SetReorder( from, to, reorderPrim, fwdParams); } - reorderPrim->SetMemory(from, to); return reorderPrim; } @@ -192,20 +196,8 @@ class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory { static string CreateKey(const memory* from, const memory* to, const MklReorderWithScaleFwdParams& fwdParams) { - string dtypes = string(""); - string prefix = "reorder"; FactoryKeyCreator key_creator; - auto const& from_desc = from->get_primitive_desc().desc().data; - auto const& to_desc = to->get_primitive_desc().desc().data; - - key_creator.AddAsKey(prefix); - key_creator.AddAsKey(static_cast(from_desc.format)); - key_creator.AddAsKey(static_cast(from_desc.data_type)); - key_creator.AddAsKey(fwdParams.src_dims); - key_creator.AddAsKey(static_cast(to_desc.format)); - key_creator.AddAsKey(static_cast(to_desc.data_type)); - key_creator.AddAsKey(fwdParams.dtypes); - + key_creator.AddAsKey(MklReorderPrimitiveFactory::CreateKey(from, to)); // Generate key for post-op scale if (fwdParams.post_op_params.name == "scale") { DCHECK_EQ(fwdParams.post_op_params.param.size(), 1); @@ -231,21 +223,6 @@ class MklReorderWithScalePrimitiveFactory : public MklPrimitiveFactory { } }; -// Fuction to find (or create) a reorder from memory pointed by -// 'from' to memory pointed by 'to', it will create primitive or -// get primitive from pool if it is cached. -// Returns the primitive. -template -inline primitive FindOrCreateReorder( - const memory* from, const memory* to, - const MklReorderWithScaleFwdParams& fwdParams) { - DCHECK(from); - DCHECK(to); - MklReorderWithScalePrimitive* reorder_prim = - MklReorderWithScalePrimitiveFactory::Get(from, to, fwdParams); - return *reorder_prim->GetPrimitive(); -} - // Quantizes a tensor from float to T, with user-specified min_range and // max_range. template @@ -300,7 +277,7 @@ class MklQuantizeV2Op : public OpKernel { "Scalar calculation in MKL is supported only for" "MIN_FIRST mode for now.")); - auto cpu_engine = engine(engine::cpu, 0); + auto cpu_engine = engine(ENGINE_CPU, 0); const Tensor& input = ctx->input(0); const unsigned int src_idx = 0; const Tensor& src_tensor = MklGetInput(ctx, src_idx); @@ -366,7 +343,7 @@ class MklQuantizeV2Op : public OpKernel { max_range = std::max(input_max_range, min_range + epsilon); // Clamping the max_range to zero since max_range can also be negative. max_range = std::max(0.0f, max_range); - auto cpu_engine = engine(engine::cpu, 0); + auto cpu_engine = engine(ENGINE_CPU, 0); const Tensor& src_tensor = MklGetInput(ctx, src_idx); MklDnnShape src_mkl_shape; GetMklShape(ctx, src_idx, &src_mkl_shape); @@ -377,25 +354,25 @@ class MklQuantizeV2Op : public OpKernel { : TFShapeToMklDnnDims(src_tensor.shape()); auto output_dims = src_dims; // Set the dst layout to be the best mkl layout based on dims and type. - memory::format dst_layout_type; + MEMORY_FORMAT dst_layout_type; switch (src_tf_shape.dims()) { case 0: ComputeScalar(ctx, min_range, max_range); return; case 1: - dst_layout_type = memory::format::x; + dst_layout_type = MEMORY_FORMAT::x; break; case 2: - dst_layout_type = memory::format::nc; + dst_layout_type = MEMORY_FORMAT::nc; break; case 3: - dst_layout_type = memory::format::tnc; + dst_layout_type = MEMORY_FORMAT::tnc; break; case 4: - dst_layout_type = memory::format::nhwc; + dst_layout_type = MEMORY_FORMAT::nhwc; break; case 5: - dst_layout_type = memory::format::ndhwc; + dst_layout_type = MEMORY_FORMAT::ndhwc; break; default: OP_REQUIRES_OK(ctx, @@ -414,11 +391,11 @@ class MklQuantizeV2Op : public OpKernel { // If the mode is min_first, input data has to be subtracted from // min_range, before being scaled auto flat_input = input.flat().data(); - Tensor minfirst_tmpinput; - OP_REQUIRES_OK( - ctx, ctx->allocate_temp(DT_FLOAT, input.shape(), &minfirst_tmpinput)); + Tensor min_shifted_input_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, input.shape(), + &min_shifted_input_tensor)); if (mode_ == QUANTIZE_MODE_MIN_FIRST) { - auto minfirst_input = minfirst_tmpinput.flat().data(); + auto minfirst_input = min_shifted_input_tensor.flat().data(); const Eigen::TensorOpCost cost( sizeof(float), /*load bytes*/ sizeof(float), /*saved bytes*/ @@ -432,25 +409,27 @@ class MklQuantizeV2Op : public OpKernel { }; d.parallelFor(input.NumElements(), cost, ParallelSub); - src.SetUsrMem(src_md, minfirst_input); + src.SetUsrMem(src_md, &min_shifted_input_tensor); } else { src.SetUsrMem(src_md, &src_tensor); } memory::desc dst_md = memory::desc(src_dims, MklDnnType(), dst_layout_type); - auto dst_pd = src.GetUsrMemPrimDesc(); +#ifndef ENABLE_MKLDNN_V1 + auto dst_pd = memory::primitive_desc(dst_md, cpu_engine); +#endif // !ENABLE_MKLDNN_V1 // Standard shape assignments for layout pass MklDnnShape output_mkl_shape; TensorShape output_tf_shape; if (src_mkl_shape.IsMklTensor()) { output_mkl_shape.SetMklTensor(true); - output_mkl_shape.SetMklLayout(&dst_md); + output_mkl_shape.SetMklLayout(&DST_MD); output_mkl_shape.SetElemType(MklDnnType()); output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(), src_mkl_shape.GetSizesAsMklDnnDims(), src_mkl_shape.GetTfDataFormat()); - output_tf_shape.AddDim(dst_pd.get_size() / sizeof(T)); + output_tf_shape.AddDim(DST_MD.get_size() / sizeof(T)); } else { output_mkl_shape.SetMklTensor(false); output_tf_shape = MklDnnDimsToTFShape(output_dims); @@ -459,6 +438,8 @@ class MklQuantizeV2Op : public OpKernel { Tensor* output_tensor = nullptr; AllocateOutputSetMklShape(ctx, 0, &output_tensor, output_tf_shape, output_mkl_shape); + dst.SetUsrMem(dst_md, output_tensor); + TensorShape min_tf_shape = {}; MklDnnShape min_mkl_shape; min_mkl_shape.SetMklTensor(false); @@ -472,8 +453,6 @@ class MklQuantizeV2Op : public OpKernel { AllocateOutputSetMklShape(ctx, 2, &output_max_tensor, max_tf_shape, max_mkl_shape); - dst.SetUsrMem(dst_md, output_tensor); - float scale_factor = 0; if (mode_ == QUANTIZE_MODE_SCALED) { // Estimating scales for quantization. @@ -497,41 +476,25 @@ class MklQuantizeV2Op : public OpKernel { target_range = static_cast((uint64_t{1} << num_bits) - 1); } scale_factor = target_range / max_abs; - - output_min_tensor->flat()(0) = min_range; - output_max_tensor->flat()(0) = max_range; - - // Primitive creation and stream submit - std::vector scales{scale_factor}; - mkldnn::primitive_attr attr; - attr.set_output_scales(0, scales); - auto reorder_desc = reorder::primitive_desc( - src.GetUsrMemPrimDesc(), dst.GetUsrMemPrimDesc(), attr); - reorder my_reorder = reorder( - reorder_desc, primitive::at(*src.GetUsrMem()), *dst.GetUsrMem()); - std::vector net{my_reorder}; - stream(stream::kind::eager).submit(net).wait(); } else if (mode_ == QUANTIZE_MODE_MIN_FIRST) { // Estimate scale for qunatization const int number_of_bits = sizeof(T) * 8; const int64 number_of_steps = static_cast(1) << number_of_bits; scale_factor = (number_of_steps - 1.0) / (max_range - min_range); - - output_min_tensor->flat()(0) = min_range; - output_max_tensor->flat()(0) = max_range; - - MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md); - fwdParams.dtypes.append(typeid(T).name()); - - fwdParams.post_op_params.name = "scale"; - fwdParams.post_op_params.param.push_back(scale_factor); - - // Get primitive from pool or create one and submit - std::vector net; - net.push_back( - FindOrCreateReorder(src.GetUsrMem(), dst.GetUsrMem(), fwdParams)); - stream(stream::kind::eager).submit(net).wait(); } + + MklReorderWithScaleFwdParams fwdParams(src_dims, src_md, dst_md); + fwdParams.dtypes.append(typeid(T).name()); + fwdParams.post_op_params.name = "scale"; + fwdParams.post_op_params.param.push_back(scale_factor); + + MklReorderWithScalePrimitive* reorder_prim = + MklReorderWithScalePrimitiveFactory::Get(src.GetUsrMem(), + dst.GetUsrMem(), fwdParams); + reorder_prim->Execute(src.GetUsrMemDataHandle(), dst.GetUsrMemDataHandle()); + + output_min_tensor->flat()(0) = min_range; + output_max_tensor->flat()(0) = max_range; } private: diff --git a/tensorflow/core/util/mkl_types.h b/tensorflow/core/util/mkl_types.h index eede9b6087f..558c57a1851 100644 --- a/tensorflow/core/util/mkl_types.h +++ b/tensorflow/core/util/mkl_types.h @@ -39,6 +39,7 @@ namespace tensorflow { #define GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr) mem_ptr->get_desc() #define GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(mem_ptr) \ GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr) +#define GET_MEMORY_SIZE_FROM_MD(md, engine) md.get_size() #define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd->src_desc() #define GET_DIFF_DST_DESC_FROM_OP_PD(op_pd) op_pd->diff_dst_desc() #define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) op_pd->workspace_desc() @@ -131,6 +132,8 @@ namespace tensorflow { #define GET_BLOCK_STRIDES(strides, idx) strides[(idx)] #define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \ { {dims}, MklDnnType(), fm } +#define GET_MEMORY_SIZE_FROM_MD(md, engine) \ + memory::primitive_desc(md, engine).get_size() #define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd.get()->src_primitive_desc() #define GET_DIFF_DST_DESC_FROM_OP_PD(op_pd) \ op_pd.get()->diff_dst_primitive_desc() diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index a782e76547b..582b0525323 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -2078,10 +2078,6 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory { return instance_; } - private: - MklReorderPrimitiveFactory() {} - ~MklReorderPrimitiveFactory() {} - static string CreateKey(const memory* from, const memory* to) { string prefix = "reorder"; FactoryKeyCreator key_creator; @@ -2117,6 +2113,10 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory { return key_creator.GetKey(); } + private: + MklReorderPrimitiveFactory() {} + ~MklReorderPrimitiveFactory() {} + MklPrimitive* GetReorder(const memory* from, const memory* to) { string key = CreateKey(from, to); return this->GetOp(key);