Merge pull request #39519 from Intel-tensorflow:sriniva2/threadpool_pooling
PiperOrigin-RevId: 313437523 Change-Id: I9ebb625cc949eef464de2fcbb0ce77635e7c41e8
This commit is contained in:
commit
54e57d69d2
@ -136,9 +136,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
|
||||
const T* src_data = input_tensor.flat<T>().data();
|
||||
|
||||
T* dst_data = output_tensor->flat<T>().data();
|
||||
|
||||
std::shared_ptr<stream> fwd_cpu_stream;
|
||||
fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));
|
||||
// Execute pooling op.
|
||||
pooling_fwd->Execute(src_data, dst_data);
|
||||
pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream);
|
||||
|
||||
// Pass min, max from input to output.
|
||||
if (int8_forward_inference) {
|
||||
@ -240,8 +241,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
: memory::desc(diff_dst_dims, MklDnnType<T>(),
|
||||
this->data_format_mkldnn_);
|
||||
|
||||
// Pass prop_kind::forward_training to create a forward primitive
|
||||
// that is used in the backward pass.
|
||||
// Pass prop_kind::forward_training to create a forward primitive
|
||||
// that is used in the backward pass.
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// TODO(DNNL): Find out what should we use src_md.data.format.
|
||||
MklPoolingParams bwdParams(
|
||||
@ -260,6 +261,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
MklPoolingBwdPrimitive<T>* pooling_bwd =
|
||||
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
|
||||
|
||||
std::shared_ptr<stream> bwd_cpu_stream;
|
||||
bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
|
||||
Tensor* output_tensor = nullptr;
|
||||
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
|
||||
orig_input_dims_mkl_order,
|
||||
@ -286,7 +289,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
T* diff_src_data = output_tensor->flat<T>().data();
|
||||
|
||||
// Execute pooling op.
|
||||
pooling_bwd->Execute(diff_dst_data, diff_src_data);
|
||||
pooling_bwd->Execute(diff_dst_data, diff_src_data, nullptr,
|
||||
bwd_cpu_stream);
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = "Status: " + std::to_string(e.status) +
|
||||
", message: " + string(e.message) + ", in file " +
|
||||
|
@ -167,10 +167,12 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
|
||||
const T* src_data = input_tensor.flat<T>().data();
|
||||
|
||||
T* dst_data = output_tensor->flat<T>().data();
|
||||
std::shared_ptr<stream> fwd_cpu_stream;
|
||||
fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));
|
||||
|
||||
if (int8_forward_inference) {
|
||||
// Execute pooling op
|
||||
pooling_fwd->Execute(src_data, dst_data);
|
||||
pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream);
|
||||
|
||||
// Pass min, max from input to output.
|
||||
const Tensor& min_input_t = MklGetInput(context, 1);
|
||||
@ -197,7 +199,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
|
||||
T* ws_data =
|
||||
static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle());
|
||||
// Execute pooling op.
|
||||
pooling_fwd->Execute(src_data, dst_data, ws_data);
|
||||
pooling_fwd->Execute(src_data, dst_data, ws_data, fwd_cpu_stream);
|
||||
}
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = "Status: " + std::to_string(e.status) +
|
||||
@ -322,6 +324,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
MklPoolingBwdPrimitive<T>* pooling_bwd =
|
||||
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
|
||||
|
||||
std::shared_ptr<stream> bwd_cpu_stream;
|
||||
bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
|
||||
// Allocate output tensor and memory primitive.
|
||||
Tensor* output_tensor = nullptr;
|
||||
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
|
||||
@ -335,8 +339,10 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd,
|
||||
pooling_bwd)) {
|
||||
grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
|
||||
grad_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd), cpu_engine_));
|
||||
grad_dnn_data.CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd),
|
||||
cpu_engine_),
|
||||
context);
|
||||
diff_dst_data =
|
||||
static_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle());
|
||||
} else {
|
||||
@ -361,7 +367,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
||||
T* diff_src_data = output_tensor->flat<T>().data();
|
||||
|
||||
// Execute pooling op.
|
||||
pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);
|
||||
pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data,
|
||||
bwd_cpu_stream);
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = "Status:" + std::to_string(e.status) +
|
||||
", message: " + string(e.message) + ". in file " +
|
||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/kernel_shape_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
using mkldnn::prop_kind;
|
||||
|
||||
@ -38,11 +37,11 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
|
||||
context_.alg_kind = fwdParams.alg_kind;
|
||||
context_.prop_kind = fwdParams.prop_kind;
|
||||
|
||||
// Create memory descriptor
|
||||
// FIXME: Pooling doesn't expose a way to get the src_primitive_desc,
|
||||
// so src format is currently hard-coded.
|
||||
// A utility function is used to do this,
|
||||
// which may be broken with future CPU architectures
|
||||
// Create memory descriptor
|
||||
// FIXME: Pooling doesn't expose a way to get the src_primitive_desc,
|
||||
// so src format is currently hard-coded.
|
||||
// A utility function is used to do this,
|
||||
// which may be broken with future CPU architectures
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
bool is_2d = (fwdParams.src_dims.size() == 4);
|
||||
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value)
|
||||
@ -126,7 +125,8 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
|
||||
|
||||
template <typename T>
|
||||
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
|
||||
void* ws_data) {
|
||||
void* ws_data,
|
||||
std::shared_ptr<stream> fwd_stream) {
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
||||
@ -138,10 +138,9 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
|
||||
}
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
|
||||
context_.net_args);
|
||||
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
|
||||
#else
|
||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
||||
fwd_stream->submit(context_.fwd_primitives);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
// Set back data handle.
|
||||
@ -268,7 +267,8 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
|
||||
|
||||
template <typename T>
|
||||
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
|
||||
T* diff_src_data, const void* ws_data) {
|
||||
T* diff_src_data, const void* ws_data,
|
||||
std::shared_ptr<stream> bwd_stream) {
|
||||
context_.diff_dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
||||
@ -278,10 +278,9 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
|
||||
}
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.bwd_primitives, context_.bwd_stream,
|
||||
context_.net_args);
|
||||
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
|
||||
#else
|
||||
context_.bwd_stream->submit(context_.bwd_primitives);
|
||||
bwd_stream->submit(context_.bwd_primitives);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
// Set back data handle.
|
||||
|
@ -86,8 +86,7 @@ template <typename T>
|
||||
class MklPoolingFwdPrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
|
||||
: cpu_engine_(ENGINE_CPU, 0) {
|
||||
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
if (context_.fwd == nullptr) Setup(fwdParams);
|
||||
}
|
||||
|
||||
@ -97,7 +96,8 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
|
||||
// src_data: input data buffer of src
|
||||
// ws_data: output data buffer of workspace
|
||||
// dst_data: output data buffer of dst
|
||||
void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr);
|
||||
void Execute(const T* src_data, T* dst_data, void* ws_data,
|
||||
std::shared_ptr<stream> fwd_stream);
|
||||
|
||||
std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
|
||||
return context_.fwd_pd;
|
||||
@ -159,12 +159,10 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
|
||||
fwd_pd(nullptr),
|
||||
src_md(nullptr),
|
||||
dst_md(nullptr),
|
||||
fwd(nullptr),
|
||||
fwd_stream(nullptr) {}
|
||||
fwd(nullptr) {}
|
||||
};
|
||||
|
||||
struct PoolingFwdContext context_;
|
||||
engine cpu_engine_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -229,8 +227,7 @@ template <typename T>
|
||||
class MklPoolingBwdPrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
|
||||
: cpu_engine_(ENGINE_CPU, 0) {
|
||||
context_.bwd_stream.reset(new CPU_STREAM(cpu_engine_));
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
if (context_.bwd == nullptr) Setup(bwdParams);
|
||||
}
|
||||
|
||||
@ -240,8 +237,8 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
|
||||
// diff_dst_data: input data buffer of diff_dst
|
||||
// diff_src_data: output data buffer of diff_src
|
||||
// ws_data: input data buffer of workspace
|
||||
void Execute(const T* diff_dst_data, T* diff_src_data,
|
||||
const void* ws_data = nullptr);
|
||||
void Execute(const T* diff_dst_data, T* diff_src_data, const void* ws_data,
|
||||
std::shared_ptr<stream> bwd_stream);
|
||||
|
||||
public:
|
||||
std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
|
||||
@ -315,12 +312,10 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
|
||||
bwd_desc(nullptr),
|
||||
fwd_pd(nullptr),
|
||||
bwd_pd(nullptr),
|
||||
bwd(nullptr),
|
||||
bwd_stream(nullptr) {}
|
||||
bwd(nullptr) {}
|
||||
};
|
||||
|
||||
struct PoolingBwdContext context_;
|
||||
engine cpu_engine_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
Loading…
x
Reference in New Issue
Block a user