DNN 0.x cleanup - MKL quantize/dequantize ops
This commit is contained in:
parent eb0ea35638
commit 8cdd7fc86e
				| @ -26,7 +26,6 @@ limitations under the License. | ||||
| #include "tensorflow/core/kernels/meta_support.h" | ||||
| #include "tensorflow/core/kernels/quantization_utils.h" | ||||
| #include "tensorflow/core/lib/core/errors.h" | ||||
| #include "tensorflow/core/util/mkl_types.h" | ||||
| #include "tensorflow/core/util/mkl_util.h" | ||||
| 
 | ||||
| using mkldnn::primitive_attr; | ||||
| @ -51,7 +50,7 @@ class MklDequantizeOp : public OpKernel { | ||||
|   void Compute(OpKernelContext* ctx) override { | ||||
|     try { | ||||
|       // Using CPU device
 | ||||
|       auto cpu_engine = engine(ENGINE_CPU, 0); | ||||
|       auto cpu_engine = engine(engine::kind::cpu, 0); | ||||
| 
 | ||||
|       // Get the inputs
 | ||||
|       const Tensor& src_tensor = MklGetInput(ctx, kSrcIndex); | ||||
| @ -82,10 +81,10 @@ class MklDequantizeOp : public OpKernel { | ||||
|       // construct input TF layout. For TF layout, although input shape
 | ||||
|       // (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
 | ||||
|       // layout
 | ||||
|       auto src_md = | ||||
|           src_mkl_shape.IsMklTensor() | ||||
|               ? src_mkl_shape.GetMklLayout() | ||||
|               : memory::desc(src_dims, MklDnnType<T>(), MEMORY_FORMAT::nhwc); | ||||
|       auto src_md = src_mkl_shape.IsMklTensor() | ||||
|                         ? src_mkl_shape.GetMklLayout() | ||||
|                         : memory::desc(src_dims, MklDnnType<T>(), | ||||
|                                        memory::format_tag::nhwc); | ||||
| 
 | ||||
|       src.SetUsrMem(src_md, &src_tensor); | ||||
|       src.SetUsrMemDataHandle(&src_tensor, reorder_stream); | ||||
| @ -93,14 +92,6 @@ class MklDequantizeOp : public OpKernel { | ||||
|       Tensor* output_tensor = nullptr; | ||||
|       MklDnnShape output_mkl_shape; | ||||
|       TensorShape output_tf_shape; | ||||
| #ifndef ENABLE_MKLDNN_V1 | ||||
|       memory::desc dst_md = | ||||
|           src_mkl_shape.IsMklTensor() | ||||
|               ? memory::desc(src_dims, MklDnnType<float>(), | ||||
|                              static_cast<MEMORY_FORMAT>(src_md.data.format)) | ||||
|               : memory::desc(src_dims, MklDnnType<float>(), | ||||
|                              MEMORY_FORMAT::nhwc); | ||||
| #else | ||||
|       memory::desc dst_md = memory::desc(); | ||||
|       if (src_mkl_shape.IsMklTensor()) { | ||||
|         dst_md = memory::desc(src_mkl_shape.GetMklLayout().data); | ||||
| @ -108,10 +99,9 @@ class MklDequantizeOp : public OpKernel { | ||||
|         // same .data field but different type.
 | ||||
|         dst_md.data.data_type = memory::convert_to_c(MklDnnType<float>()); | ||||
|       } else { | ||||
|         dst_md = | ||||
|             memory::desc(src_dims, MklDnnType<float>(), MEMORY_FORMAT::nhwc); | ||||
|         dst_md = memory::desc(src_dims, MklDnnType<float>(), | ||||
|                               memory::format_tag::nhwc); | ||||
|       } | ||||
| #endif  // !ENABLE_MKLDNN_V1
 | ||||
| 
 | ||||
|       // If input is MKL shape, output is also MKL shape.
 | ||||
|       // If input is TF shape, output is also TF shape.
 | ||||
| @ -122,8 +112,7 @@ class MklDequantizeOp : public OpKernel { | ||||
|         output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(), | ||||
|                                      src_mkl_shape.GetSizesAsMklDnnDims(), | ||||
|                                      src_mkl_shape.GetTfDataFormat()); | ||||
|         output_tf_shape.AddDim(GET_MEMORY_SIZE_FROM_MD(dst_md, cpu_engine) / | ||||
|                                sizeof(float)); | ||||
|         output_tf_shape.AddDim(dst_md.get_size() / sizeof(float)); | ||||
|       } else { | ||||
|         output_mkl_shape.SetMklTensor(false); | ||||
|         output_tf_shape = MklDnnDimsToTFShape(output_dims); | ||||
| @ -155,33 +144,21 @@ class MklDequantizeOp : public OpKernel { | ||||
|       scales.push_back(scale_factor); | ||||
|       primitive_attr attr; | ||||
|       attr.set_output_scales(0, scales); | ||||
| #ifndef ENABLE_MKLDNN_V1 | ||||
|       // MKL-DNN 1.0 does not provide set_int_output_round_mode() API.
 | ||||
|       // Also it does not define round_nearest (enum).
 | ||||
|       attr.set_int_output_round_mode(mkldnn::round_mode::round_nearest); | ||||
| #endif  // !ENABLE_MKLDNN_V1
 | ||||
|       std::vector<primitive> net; | ||||
| 
 | ||||
|       // Create reorder primitive and then execute.
 | ||||
|       auto reorder_pd = REORDER_PD_CONSTRUCTOR_WITH_ATTR( | ||||
|           GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(src.GetUsrMem()), | ||||
|           GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(dst.GetUsrMem()), cpu_engine, | ||||
|           attr); | ||||
| #ifdef ENABLE_MKLDNN_V1 | ||||
|       auto reorder_pd = | ||||
|           ReorderPd(cpu_engine, src.GetUsrMem()->get_desc(), cpu_engine, | ||||
|                     dst.GetUsrMem()->get_desc(), attr); | ||||
|       net.push_back(reorder(reorder_pd)); | ||||
|       std::vector<std::unordered_map<int, memory>> reorder_net_args; | ||||
|       reorder_net_args.push_back({{MKLDNN_ARG_FROM, *src.GetUsrMem()}, | ||||
|                                   { MKLDNN_ARG_TO, | ||||
|                                     *dst.GetUsrMem() }}); | ||||
|                                   {MKLDNN_ARG_TO, *dst.GetUsrMem()}}); | ||||
|       execute_primitives(net, reorder_stream, reorder_net_args); | ||||
| #else | ||||
|       net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *dst.GetUsrMem())); | ||||
|       reorder_stream->submit(net); | ||||
| #endif  // ENABLE_MKLDNN_V1
 | ||||
|     } catch (mkldnn::error& e) { | ||||
|       string error_msg = "Status: " + std::to_string(e.status) + | ||||
|                          ", message: " + string(e.message) + ", in file " + | ||||
|                          string(__FILE__) + ":" + std::to_string(__LINE__); | ||||
|       string error_msg = "Status: " + std::to_string(e.status) + ", message: " + | ||||
|                          string(e.message) + ", in file " + string(__FILE__) + | ||||
|                          ":" + std::to_string(__LINE__); | ||||
|       OP_REQUIRES_OK( | ||||
|           ctx, errors::Aborted("Operation received an exception:", error_msg)); | ||||
|     } | ||||
|  | ||||
| @ -25,7 +25,6 @@ limitations under the License. | ||||
| #include "tensorflow/core/graph/mkl_graph_util.h" | ||||
| #include "tensorflow/core/lib/core/errors.h" | ||||
| #include "tensorflow/core/platform/logging.h" | ||||
| #include "tensorflow/core/util/mkl_types.h" | ||||
| #include "tensorflow/core/util/mkl_util.h" | ||||
| 
 | ||||
| using mkldnn::primitive_attr; | ||||
| @ -77,7 +76,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive { | ||||
|  public: | ||||
|   explicit MklReorderWithScalePrimitive( | ||||
|       const MklReorderWithScaleFwdParams& fwdParams) | ||||
|       : MklPrimitive(engine(ENGINE_CPU, 0)) { | ||||
|       : MklPrimitive(engine(engine::kind::cpu, 0)) { | ||||
|     // Create reorder primitive
 | ||||
|     Setup(fwdParams); | ||||
|   } | ||||
| @ -95,11 +94,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive { | ||||
|     context_.src_mem->set_data_handle(src_data); | ||||
|     context_.dst_mem->set_data_handle(dst_data); | ||||
| #endif  // ENABLE_MKLDNN_THREADPOOL
 | ||||
| #ifndef ENABLE_MKLDNN_V1 | ||||
|     reorder_stream->submit(context_.net); | ||||
| #else | ||||
|     context_.reorder_prim->execute(*reorder_stream, context_.prim_args); | ||||
| #endif  // !ENABLE_MKLDNN_V1
 | ||||
|     // After execution, set data handle back.
 | ||||
|     context_.src_mem->set_data_handle(DummyData); | ||||
|     context_.dst_mem->set_data_handle(DummyData); | ||||
| @ -119,11 +114,7 @@ class MklReorderWithScalePrimitive : public MklPrimitive { | ||||
|     // Stream and primitive vector
 | ||||
|     std::shared_ptr<mkldnn::stream> reorder_stream; | ||||
| 
 | ||||
| #ifndef ENABLE_MKLDNN_V1 | ||||
|     std::vector<mkldnn::primitive> net; | ||||
| #else | ||||
|     std::unordered_map<int, mkldnn::memory> prim_args; | ||||
| #endif  // !ENABLE_MKLDNN_V1
 | ||||
| 
 | ||||
|     ReorderContext() | ||||
|         : src_mem(nullptr), | ||||
| @ -135,10 +126,10 @@ class MklReorderWithScalePrimitive : public MklPrimitive { | ||||
|   // Reorder primitive setup
 | ||||
|   void Setup(const MklReorderWithScaleFwdParams& fwdParams) { | ||||
|     // Create memory descriptors for reorder data with specified format
 | ||||
|     context_.src_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD( | ||||
|         fwdParams.src_md, cpu_engine_, DummyData)); | ||||
|     context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD( | ||||
|         fwdParams.dst_md, cpu_engine_, DummyData)); | ||||
|     context_.src_mem.reset( | ||||
|         new memory(fwdParams.src_md, cpu_engine_, DummyData)); | ||||
|     context_.dst_mem.reset( | ||||
|         new memory(fwdParams.dst_md, cpu_engine_, DummyData)); | ||||
| 
 | ||||
|     // Check if there is any fusion as post-ops
 | ||||
|     auto const& post_op_params = fwdParams.post_op_params; | ||||
| @ -150,21 +141,14 @@ class MklReorderWithScalePrimitive : public MklPrimitive { | ||||
|     scales.push_back(post_op_params.param[0]); | ||||
|     post_ops_attr.set_output_scales(0, scales); | ||||
| 
 | ||||
|     context_.reorder_pd.reset(new REORDER_PD_CONSTRUCTOR_WITH_ATTR( | ||||
|         GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(context_.src_mem), | ||||
|         GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(context_.dst_mem), cpu_engine_, | ||||
|         post_ops_attr)); | ||||
|     context_.reorder_pd.reset( | ||||
|         new ReorderPd(cpu_engine_, context_.src_mem->get_desc(), cpu_engine_, | ||||
|                       context_.dst_mem->get_desc(), post_ops_attr)); | ||||
| 
 | ||||
| // Create reorder primitive
 | ||||
| #ifndef ENABLE_MKLDNN_V1 | ||||
|     context_.reorder_prim.reset(new reorder( | ||||
|         *context_.reorder_pd, *context_.src_mem, *context_.dst_mem)); | ||||
|     context_.net.push_back(*context_.reorder_prim); | ||||
| #else | ||||
|     // Create reorder primitive
 | ||||
|     context_.reorder_prim.reset(new reorder(*context_.reorder_pd)); | ||||
|     context_.prim_args.insert({MKLDNN_ARG_FROM, *context_.src_mem}); | ||||
|     context_.prim_args.insert({MKLDNN_ARG_TO, *context_.dst_mem}); | ||||
| #endif  // !ENABLE_MKLDNN_V1
 | ||||
|   } | ||||
| }; | ||||
| 
 | ||||
| @ -232,9 +216,8 @@ class MklQuantizeV2Op : public OpKernel { | ||||
|   explicit MklQuantizeV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) { | ||||
|     string mode_string; | ||||
|     OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string)); | ||||
|     OP_REQUIRES(ctx, | ||||
|                 (mode_string == "MIN_COMBINED" || mode_string == "MIN_FIRST" || | ||||
|                  mode_string == "SCALED"), | ||||
|     OP_REQUIRES(ctx, (mode_string == "MIN_COMBINED" || | ||||
|                       mode_string == "MIN_FIRST" || mode_string == "SCALED"), | ||||
|                 errors::InvalidArgument("Mode string must be 'MIN_COMBINED'," | ||||
|                                         " 'MIN_FIRST', or 'SCALED', is '" + | ||||
|                                         mode_string + "'")); | ||||
| @ -248,9 +231,8 @@ class MklQuantizeV2Op : public OpKernel { | ||||
| 
 | ||||
|     string round_mode_string; | ||||
|     OP_REQUIRES_OK(ctx, ctx->GetAttr("round_mode", &round_mode_string)); | ||||
|     OP_REQUIRES(ctx, | ||||
|                 (round_mode_string == "HALF_AWAY_FROM_ZERO" || | ||||
|                  round_mode_string == "HALF_TO_EVEN"), | ||||
|     OP_REQUIRES(ctx, (round_mode_string == "HALF_AWAY_FROM_ZERO" || | ||||
|                       round_mode_string == "HALF_TO_EVEN"), | ||||
|                 errors::InvalidArgument("Round mode string must be " | ||||
|                                         "'HALF_AWAY_FROM_ZERO' or " | ||||
|                                         "'HALF_TO_EVEN', is '" + | ||||
| @ -278,7 +260,7 @@ class MklQuantizeV2Op : public OpKernel { | ||||
|                     "Scalar calculation in MKL is supported only for" | ||||
|                     "MIN_FIRST mode for now.")); | ||||
| 
 | ||||
|     auto cpu_engine = engine(ENGINE_CPU, 0); | ||||
|     auto cpu_engine = engine(engine::kind::cpu, 0); | ||||
|     const Tensor& input = ctx->input(0); | ||||
|     const unsigned int src_idx = 0; | ||||
|     const Tensor& src_tensor = MklGetInput(ctx, src_idx); | ||||
| @ -344,7 +326,7 @@ class MklQuantizeV2Op : public OpKernel { | ||||
|     max_range = std::max(input_max_range, min_range + epsilon); | ||||
|     // Clamping the max_range to zero since max_range can also be negative.
 | ||||
|     max_range = std::max(0.0f, max_range); | ||||
|     auto cpu_engine = engine(ENGINE_CPU, 0); | ||||
|     auto cpu_engine = engine(engine::kind::cpu, 0); | ||||
|     const Tensor& src_tensor = MklGetInput(ctx, src_idx); | ||||
|     MklDnnShape src_mkl_shape; | ||||
|     GetMklShape(ctx, src_idx, &src_mkl_shape); | ||||
| @ -355,25 +337,25 @@ class MklQuantizeV2Op : public OpKernel { | ||||
|                         : TFShapeToMklDnnDims(src_tensor.shape()); | ||||
|     auto output_dims = src_dims; | ||||
|     // Set the dst layout to be the best mkl layout based on dims and type.
 | ||||
|     MEMORY_FORMAT dst_layout_type; | ||||
|     memory::format_tag dst_layout_type; | ||||
|     switch (src_tf_shape.dims()) { | ||||
|       case 0: | ||||
|         ComputeScalar(ctx, min_range, max_range); | ||||
|         return; | ||||
|       case 1: | ||||
|         dst_layout_type = MEMORY_FORMAT::x; | ||||
|         dst_layout_type = memory::format_tag::x; | ||||
|         break; | ||||
|       case 2: | ||||
|         dst_layout_type = MEMORY_FORMAT::nc; | ||||
|         dst_layout_type = memory::format_tag::nc; | ||||
|         break; | ||||
|       case 3: | ||||
|         dst_layout_type = MEMORY_FORMAT::tnc; | ||||
|         dst_layout_type = memory::format_tag::tnc; | ||||
|         break; | ||||
|       case 4: | ||||
|         dst_layout_type = MEMORY_FORMAT::nhwc; | ||||
|         dst_layout_type = memory::format_tag::nhwc; | ||||
|         break; | ||||
|       case 5: | ||||
|         dst_layout_type = MEMORY_FORMAT::ndhwc; | ||||
|         dst_layout_type = memory::format_tag::ndhwc; | ||||
|         break; | ||||
|       default: | ||||
|         OP_REQUIRES_OK(ctx, | ||||
| @ -417,9 +399,7 @@ class MklQuantizeV2Op : public OpKernel { | ||||
| 
 | ||||
|     memory::desc dst_md = | ||||
|         memory::desc(src_dims, MklDnnType<T>(), dst_layout_type); | ||||
| #ifndef ENABLE_MKLDNN_V1 | ||||
|     auto dst_pd = memory::primitive_desc(dst_md, cpu_engine); | ||||
| #endif  // !ENABLE_MKLDNN_V1
 | ||||
| 
 | ||||
|     // Standard shape assignments for layout pass
 | ||||
|     MklDnnShape output_mkl_shape; | ||||
|     TensorShape output_tf_shape; | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user