Merge pull request #39519 from Intel-tensorflow:sriniva2/threadpool_pooling
PiperOrigin-RevId: 313437523 Change-Id: I9ebb625cc949eef464de2fcbb0ce77635e7c41e8
This commit is contained in:
		
						commit
						54e57d69d2
					
				| @ -136,9 +136,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> { | |||||||
|       const T* src_data = input_tensor.flat<T>().data(); |       const T* src_data = input_tensor.flat<T>().data(); | ||||||
| 
 | 
 | ||||||
|       T* dst_data = output_tensor->flat<T>().data(); |       T* dst_data = output_tensor->flat<T>().data(); | ||||||
| 
 |       std::shared_ptr<stream> fwd_cpu_stream; | ||||||
|  |       fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine())); | ||||||
|       // Execute pooling op.
 |       // Execute pooling op.
 | ||||||
|       pooling_fwd->Execute(src_data, dst_data); |       pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream); | ||||||
| 
 | 
 | ||||||
|       // Pass min, max from input to output.
 |       // Pass min, max from input to output.
 | ||||||
|       if (int8_forward_inference) { |       if (int8_forward_inference) { | ||||||
| @ -240,8 +241,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> { | |||||||
|               : memory::desc(diff_dst_dims, MklDnnType<T>(), |               : memory::desc(diff_dst_dims, MklDnnType<T>(), | ||||||
|                              this->data_format_mkldnn_); |                              this->data_format_mkldnn_); | ||||||
| 
 | 
 | ||||||
|       // Pass prop_kind::forward_training to create a forward primitive
 | // Pass prop_kind::forward_training to create a forward primitive
 | ||||||
|       // that is used in the backward pass.
 | // that is used in the backward pass.
 | ||||||
| #ifdef ENABLE_MKLDNN_V1 | #ifdef ENABLE_MKLDNN_V1 | ||||||
|       // TODO(DNNL): Find out what should we use src_md.data.format.
 |       // TODO(DNNL): Find out what should we use src_md.data.format.
 | ||||||
|       MklPoolingParams bwdParams( |       MklPoolingParams bwdParams( | ||||||
| @ -260,6 +261,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> { | |||||||
|       MklPoolingBwdPrimitive<T>* pooling_bwd = |       MklPoolingBwdPrimitive<T>* pooling_bwd = | ||||||
|           MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); |           MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); | ||||||
| 
 | 
 | ||||||
|  |       std::shared_ptr<stream> bwd_cpu_stream; | ||||||
|  |       bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine())); | ||||||
|       Tensor* output_tensor = nullptr; |       Tensor* output_tensor = nullptr; | ||||||
|       this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()), |       this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()), | ||||||
|                                  orig_input_dims_mkl_order, |                                  orig_input_dims_mkl_order, | ||||||
| @ -286,7 +289,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> { | |||||||
|       T* diff_src_data = output_tensor->flat<T>().data(); |       T* diff_src_data = output_tensor->flat<T>().data(); | ||||||
| 
 | 
 | ||||||
|       // Execute pooling op.
 |       // Execute pooling op.
 | ||||||
|       pooling_bwd->Execute(diff_dst_data, diff_src_data); |       pooling_bwd->Execute(diff_dst_data, diff_src_data, nullptr, | ||||||
|  |                            bwd_cpu_stream); | ||||||
|     } catch (mkldnn::error& e) { |     } catch (mkldnn::error& e) { | ||||||
|       string error_msg = "Status: " + std::to_string(e.status) + |       string error_msg = "Status: " + std::to_string(e.status) + | ||||||
|                          ", message: " + string(e.message) + ", in file " + |                          ", message: " + string(e.message) + ", in file " + | ||||||
|  | |||||||
| @ -167,10 +167,12 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { | |||||||
|       const T* src_data = input_tensor.flat<T>().data(); |       const T* src_data = input_tensor.flat<T>().data(); | ||||||
| 
 | 
 | ||||||
|       T* dst_data = output_tensor->flat<T>().data(); |       T* dst_data = output_tensor->flat<T>().data(); | ||||||
|  |       std::shared_ptr<stream> fwd_cpu_stream; | ||||||
|  |       fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine())); | ||||||
| 
 | 
 | ||||||
|       if (int8_forward_inference) { |       if (int8_forward_inference) { | ||||||
|         // Execute pooling op
 |         // Execute pooling op
 | ||||||
|         pooling_fwd->Execute(src_data, dst_data); |         pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream); | ||||||
| 
 | 
 | ||||||
|         // Pass min, max from input to output.
 |         // Pass min, max from input to output.
 | ||||||
|         const Tensor& min_input_t = MklGetInput(context, 1); |         const Tensor& min_input_t = MklGetInput(context, 1); | ||||||
| @ -197,7 +199,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { | |||||||
|         T* ws_data = |         T* ws_data = | ||||||
|             static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle()); |             static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle()); | ||||||
|         // Execute pooling op.
 |         // Execute pooling op.
 | ||||||
|         pooling_fwd->Execute(src_data, dst_data, ws_data); |         pooling_fwd->Execute(src_data, dst_data, ws_data, fwd_cpu_stream); | ||||||
|       } |       } | ||||||
|     } catch (mkldnn::error& e) { |     } catch (mkldnn::error& e) { | ||||||
|       string error_msg = "Status: " + std::to_string(e.status) + |       string error_msg = "Status: " + std::to_string(e.status) + | ||||||
| @ -322,6 +324,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { | |||||||
|       MklPoolingBwdPrimitive<T>* pooling_bwd = |       MklPoolingBwdPrimitive<T>* pooling_bwd = | ||||||
|           MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); |           MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); | ||||||
| 
 | 
 | ||||||
|  |       std::shared_ptr<stream> bwd_cpu_stream; | ||||||
|  |       bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine())); | ||||||
|       // Allocate output tensor and memory primitive.
 |       // Allocate output tensor and memory primitive.
 | ||||||
|       Tensor* output_tensor = nullptr; |       Tensor* output_tensor = nullptr; | ||||||
|       this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()), |       this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()), | ||||||
| @ -335,8 +339,10 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { | |||||||
|       if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd, |       if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd, | ||||||
|                                      pooling_bwd)) { |                                      pooling_bwd)) { | ||||||
|         grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor); |         grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor); | ||||||
|         grad_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( |         grad_dnn_data.CheckReorderToOpMem( | ||||||
|             GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd), cpu_engine_)); |             MEMORY_PD_WITHOUT_DATA(GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd), | ||||||
|  |                                    cpu_engine_), | ||||||
|  |             context); | ||||||
|         diff_dst_data = |         diff_dst_data = | ||||||
|             static_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle()); |             static_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle()); | ||||||
|       } else { |       } else { | ||||||
| @ -361,7 +367,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { | |||||||
|       T* diff_src_data = output_tensor->flat<T>().data(); |       T* diff_src_data = output_tensor->flat<T>().data(); | ||||||
| 
 | 
 | ||||||
|       // Execute pooling op.
 |       // Execute pooling op.
 | ||||||
|       pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data); |       pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data, | ||||||
|  |                            bwd_cpu_stream); | ||||||
|     } catch (mkldnn::error& e) { |     } catch (mkldnn::error& e) { | ||||||
|       string error_msg = "Status:" + std::to_string(e.status) + |       string error_msg = "Status:" + std::to_string(e.status) + | ||||||
|                          ", message: " + string(e.message) + ". in file " + |                          ", message: " + string(e.message) + ". in file " + | ||||||
|  | |||||||
| @ -23,7 +23,6 @@ limitations under the License. | |||||||
| #include "tensorflow/core/common_runtime/device.h" | #include "tensorflow/core/common_runtime/device.h" | ||||||
| #include "tensorflow/core/framework/bounds_check.h" | #include "tensorflow/core/framework/bounds_check.h" | ||||||
| #include "tensorflow/core/framework/kernel_shape_util.h" | #include "tensorflow/core/framework/kernel_shape_util.h" | ||||||
| 
 |  | ||||||
| namespace tensorflow { | namespace tensorflow { | ||||||
| using mkldnn::prop_kind; | using mkldnn::prop_kind; | ||||||
| 
 | 
 | ||||||
| @ -38,11 +37,11 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) { | |||||||
|   context_.alg_kind = fwdParams.alg_kind; |   context_.alg_kind = fwdParams.alg_kind; | ||||||
|   context_.prop_kind = fwdParams.prop_kind; |   context_.prop_kind = fwdParams.prop_kind; | ||||||
| 
 | 
 | ||||||
|   // Create memory descriptor
 | // Create memory descriptor
 | ||||||
|   // FIXME: Pooling doesn't expose a way to get the src_primitive_desc,
 | // FIXME: Pooling doesn't expose a way to get the src_primitive_desc,
 | ||||||
|   //        so src format is currently hard-coded.
 | //        so src format is currently hard-coded.
 | ||||||
|   //        A utility function is used to do this,
 | //        A utility function is used to do this,
 | ||||||
|   //        which may be broken with future CPU architectures
 | //        which may be broken with future CPU architectures
 | ||||||
| #ifndef ENABLE_MKLDNN_V1 | #ifndef ENABLE_MKLDNN_V1 | ||||||
|   bool is_2d = (fwdParams.src_dims.size() == 4); |   bool is_2d = (fwdParams.src_dims.size() == 4); | ||||||
|   if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) |   if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) | ||||||
| @ -126,7 +125,8 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) { | |||||||
| 
 | 
 | ||||||
| template <typename T> | template <typename T> | ||||||
| void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, | void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, | ||||||
|                                         void* ws_data) { |                                         void* ws_data, | ||||||
|  |                                         std::shared_ptr<stream> fwd_stream) { | ||||||
|   context_.src_mem->set_data_handle( |   context_.src_mem->set_data_handle( | ||||||
|       static_cast<void*>(const_cast<T*>(src_data))); |       static_cast<void*>(const_cast<T*>(src_data))); | ||||||
|   context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); |   context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); | ||||||
| @ -138,10 +138,9 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, | |||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| #ifdef ENABLE_MKLDNN_V1 | #ifdef ENABLE_MKLDNN_V1 | ||||||
|   execute_primitives(context_.fwd_primitives, context_.fwd_stream, |   execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args); | ||||||
|                      context_.net_args); |  | ||||||
| #else | #else | ||||||
|   context_.fwd_stream->submit(context_.fwd_primitives); |   fwd_stream->submit(context_.fwd_primitives); | ||||||
| #endif  // ENABLE_MKLDNN_V1
 | #endif  // ENABLE_MKLDNN_V1
 | ||||||
| 
 | 
 | ||||||
|   // Set back data handle.
 |   // Set back data handle.
 | ||||||
| @ -268,7 +267,8 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) { | |||||||
| 
 | 
 | ||||||
| template <typename T> | template <typename T> | ||||||
| void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, | void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, | ||||||
|                                         T* diff_src_data, const void* ws_data) { |                                         T* diff_src_data, const void* ws_data, | ||||||
|  |                                         std::shared_ptr<stream> bwd_stream) { | ||||||
|   context_.diff_dst_mem->set_data_handle( |   context_.diff_dst_mem->set_data_handle( | ||||||
|       static_cast<void*>(const_cast<T*>(diff_dst_data))); |       static_cast<void*>(const_cast<T*>(diff_dst_data))); | ||||||
|   context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); |   context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); | ||||||
| @ -278,10 +278,9 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, | |||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| #ifdef ENABLE_MKLDNN_V1 | #ifdef ENABLE_MKLDNN_V1 | ||||||
|   execute_primitives(context_.bwd_primitives, context_.bwd_stream, |   execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args); | ||||||
|                      context_.net_args); |  | ||||||
| #else | #else | ||||||
|   context_.bwd_stream->submit(context_.bwd_primitives); |   bwd_stream->submit(context_.bwd_primitives); | ||||||
| #endif  // ENABLE_MKLDNN_V1
 | #endif  // ENABLE_MKLDNN_V1
 | ||||||
| 
 | 
 | ||||||
|   // Set back data handle.
 |   // Set back data handle.
 | ||||||
|  | |||||||
| @ -86,8 +86,7 @@ template <typename T> | |||||||
| class MklPoolingFwdPrimitive : public MklPrimitive { | class MklPoolingFwdPrimitive : public MklPrimitive { | ||||||
|  public: |  public: | ||||||
|   explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams) |   explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams) | ||||||
|       : cpu_engine_(ENGINE_CPU, 0) { |       : MklPrimitive(engine(ENGINE_CPU, 0)) { | ||||||
|     context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_)); |  | ||||||
|     if (context_.fwd == nullptr) Setup(fwdParams); |     if (context_.fwd == nullptr) Setup(fwdParams); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -97,7 +96,8 @@ class MklPoolingFwdPrimitive : public MklPrimitive { | |||||||
|   //   src_data:  input data buffer of src
 |   //   src_data:  input data buffer of src
 | ||||||
|   //   ws_data:   output data buffer of workspace
 |   //   ws_data:   output data buffer of workspace
 | ||||||
|   //   dst_data:  output data buffer of dst
 |   //   dst_data:  output data buffer of dst
 | ||||||
|   void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr); |   void Execute(const T* src_data, T* dst_data, void* ws_data, | ||||||
|  |                std::shared_ptr<stream> fwd_stream); | ||||||
| 
 | 
 | ||||||
|   std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const { |   std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const { | ||||||
|     return context_.fwd_pd; |     return context_.fwd_pd; | ||||||
| @ -159,12 +159,10 @@ class MklPoolingFwdPrimitive : public MklPrimitive { | |||||||
|           fwd_pd(nullptr), |           fwd_pd(nullptr), | ||||||
|           src_md(nullptr), |           src_md(nullptr), | ||||||
|           dst_md(nullptr), |           dst_md(nullptr), | ||||||
|           fwd(nullptr), |           fwd(nullptr) {} | ||||||
|           fwd_stream(nullptr) {} |  | ||||||
|   }; |   }; | ||||||
| 
 | 
 | ||||||
|   struct PoolingFwdContext context_; |   struct PoolingFwdContext context_; | ||||||
|   engine cpu_engine_; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| template <typename T> | template <typename T> | ||||||
| @ -229,8 +227,7 @@ template <typename T> | |||||||
| class MklPoolingBwdPrimitive : public MklPrimitive { | class MklPoolingBwdPrimitive : public MklPrimitive { | ||||||
|  public: |  public: | ||||||
|   explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams) |   explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams) | ||||||
|       : cpu_engine_(ENGINE_CPU, 0) { |       : MklPrimitive(engine(ENGINE_CPU, 0)) { | ||||||
|     context_.bwd_stream.reset(new CPU_STREAM(cpu_engine_)); |  | ||||||
|     if (context_.bwd == nullptr) Setup(bwdParams); |     if (context_.bwd == nullptr) Setup(bwdParams); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
| @ -240,8 +237,8 @@ class MklPoolingBwdPrimitive : public MklPrimitive { | |||||||
|   //   diff_dst_data:  input data buffer of diff_dst
 |   //   diff_dst_data:  input data buffer of diff_dst
 | ||||||
|   //   diff_src_data:  output data buffer of diff_src
 |   //   diff_src_data:  output data buffer of diff_src
 | ||||||
|   //   ws_data:        input data buffer of workspace
 |   //   ws_data:        input data buffer of workspace
 | ||||||
|   void Execute(const T* diff_dst_data, T* diff_src_data, |   void Execute(const T* diff_dst_data, T* diff_src_data, const void* ws_data, | ||||||
|                const void* ws_data = nullptr); |                std::shared_ptr<stream> bwd_stream); | ||||||
| 
 | 
 | ||||||
|  public: |  public: | ||||||
|   std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const { |   std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const { | ||||||
| @ -315,12 +312,10 @@ class MklPoolingBwdPrimitive : public MklPrimitive { | |||||||
|           bwd_desc(nullptr), |           bwd_desc(nullptr), | ||||||
|           fwd_pd(nullptr), |           fwd_pd(nullptr), | ||||||
|           bwd_pd(nullptr), |           bwd_pd(nullptr), | ||||||
|           bwd(nullptr), |           bwd(nullptr) {} | ||||||
|           bwd_stream(nullptr) {} |  | ||||||
|   }; |   }; | ||||||
| 
 | 
 | ||||||
|   struct PoolingBwdContext context_; |   struct PoolingBwdContext context_; | ||||||
|   engine cpu_engine_; |  | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| template <typename T> | template <typename T> | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user