Merge pull request #39519 from Intel-tensorflow:sriniva2/threadpool_pooling
PiperOrigin-RevId: 313437523 Change-Id: I9ebb625cc949eef464de2fcbb0ce77635e7c41e8
This commit is contained in:
commit
54e57d69d2
@ -136,9 +136,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
|
|||||||
const T* src_data = input_tensor.flat<T>().data();
|
const T* src_data = input_tensor.flat<T>().data();
|
||||||
|
|
||||||
T* dst_data = output_tensor->flat<T>().data();
|
T* dst_data = output_tensor->flat<T>().data();
|
||||||
|
std::shared_ptr<stream> fwd_cpu_stream;
|
||||||
|
fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));
|
||||||
// Execute pooling op.
|
// Execute pooling op.
|
||||||
pooling_fwd->Execute(src_data, dst_data);
|
pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream);
|
||||||
|
|
||||||
// Pass min, max from input to output.
|
// Pass min, max from input to output.
|
||||||
if (int8_forward_inference) {
|
if (int8_forward_inference) {
|
||||||
@ -240,8 +241,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
|||||||
: memory::desc(diff_dst_dims, MklDnnType<T>(),
|
: memory::desc(diff_dst_dims, MklDnnType<T>(),
|
||||||
this->data_format_mkldnn_);
|
this->data_format_mkldnn_);
|
||||||
|
|
||||||
// Pass prop_kind::forward_training to create a forward primitive
|
// Pass prop_kind::forward_training to create a forward primitive
|
||||||
// that is used in the backward pass.
|
// that is used in the backward pass.
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
// TODO(DNNL): Find out what should we use src_md.data.format.
|
// TODO(DNNL): Find out what should we use src_md.data.format.
|
||||||
MklPoolingParams bwdParams(
|
MklPoolingParams bwdParams(
|
||||||
@ -260,6 +261,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
|||||||
MklPoolingBwdPrimitive<T>* pooling_bwd =
|
MklPoolingBwdPrimitive<T>* pooling_bwd =
|
||||||
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
|
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
|
||||||
|
|
||||||
|
std::shared_ptr<stream> bwd_cpu_stream;
|
||||||
|
bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
|
||||||
Tensor* output_tensor = nullptr;
|
Tensor* output_tensor = nullptr;
|
||||||
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
|
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
|
||||||
orig_input_dims_mkl_order,
|
orig_input_dims_mkl_order,
|
||||||
@ -286,7 +289,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
|||||||
T* diff_src_data = output_tensor->flat<T>().data();
|
T* diff_src_data = output_tensor->flat<T>().data();
|
||||||
|
|
||||||
// Execute pooling op.
|
// Execute pooling op.
|
||||||
pooling_bwd->Execute(diff_dst_data, diff_src_data);
|
pooling_bwd->Execute(diff_dst_data, diff_src_data, nullptr,
|
||||||
|
bwd_cpu_stream);
|
||||||
} catch (mkldnn::error& e) {
|
} catch (mkldnn::error& e) {
|
||||||
string error_msg = "Status: " + std::to_string(e.status) +
|
string error_msg = "Status: " + std::to_string(e.status) +
|
||||||
", message: " + string(e.message) + ", in file " +
|
", message: " + string(e.message) + ", in file " +
|
||||||
|
|||||||
@ -167,10 +167,12 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
|
|||||||
const T* src_data = input_tensor.flat<T>().data();
|
const T* src_data = input_tensor.flat<T>().data();
|
||||||
|
|
||||||
T* dst_data = output_tensor->flat<T>().data();
|
T* dst_data = output_tensor->flat<T>().data();
|
||||||
|
std::shared_ptr<stream> fwd_cpu_stream;
|
||||||
|
fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));
|
||||||
|
|
||||||
if (int8_forward_inference) {
|
if (int8_forward_inference) {
|
||||||
// Execute pooling op
|
// Execute pooling op
|
||||||
pooling_fwd->Execute(src_data, dst_data);
|
pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream);
|
||||||
|
|
||||||
// Pass min, max from input to output.
|
// Pass min, max from input to output.
|
||||||
const Tensor& min_input_t = MklGetInput(context, 1);
|
const Tensor& min_input_t = MklGetInput(context, 1);
|
||||||
@ -197,7 +199,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
|
|||||||
T* ws_data =
|
T* ws_data =
|
||||||
static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle());
|
static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle());
|
||||||
// Execute pooling op.
|
// Execute pooling op.
|
||||||
pooling_fwd->Execute(src_data, dst_data, ws_data);
|
pooling_fwd->Execute(src_data, dst_data, ws_data, fwd_cpu_stream);
|
||||||
}
|
}
|
||||||
} catch (mkldnn::error& e) {
|
} catch (mkldnn::error& e) {
|
||||||
string error_msg = "Status: " + std::to_string(e.status) +
|
string error_msg = "Status: " + std::to_string(e.status) +
|
||||||
@ -322,6 +324,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
|||||||
MklPoolingBwdPrimitive<T>* pooling_bwd =
|
MklPoolingBwdPrimitive<T>* pooling_bwd =
|
||||||
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
|
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
|
||||||
|
|
||||||
|
std::shared_ptr<stream> bwd_cpu_stream;
|
||||||
|
bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
|
||||||
// Allocate output tensor and memory primitive.
|
// Allocate output tensor and memory primitive.
|
||||||
Tensor* output_tensor = nullptr;
|
Tensor* output_tensor = nullptr;
|
||||||
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
|
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
|
||||||
@ -335,8 +339,10 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
|||||||
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd,
|
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd,
|
||||||
pooling_bwd)) {
|
pooling_bwd)) {
|
||||||
grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
|
grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
|
||||||
grad_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
grad_dnn_data.CheckReorderToOpMem(
|
||||||
GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd), cpu_engine_));
|
MEMORY_PD_WITHOUT_DATA(GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd),
|
||||||
|
cpu_engine_),
|
||||||
|
context);
|
||||||
diff_dst_data =
|
diff_dst_data =
|
||||||
static_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle());
|
static_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle());
|
||||||
} else {
|
} else {
|
||||||
@ -361,7 +367,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
|
|||||||
T* diff_src_data = output_tensor->flat<T>().data();
|
T* diff_src_data = output_tensor->flat<T>().data();
|
||||||
|
|
||||||
// Execute pooling op.
|
// Execute pooling op.
|
||||||
pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);
|
pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data,
|
||||||
|
bwd_cpu_stream);
|
||||||
} catch (mkldnn::error& e) {
|
} catch (mkldnn::error& e) {
|
||||||
string error_msg = "Status:" + std::to_string(e.status) +
|
string error_msg = "Status:" + std::to_string(e.status) +
|
||||||
", message: " + string(e.message) + ". in file " +
|
", message: " + string(e.message) + ". in file " +
|
||||||
|
|||||||
@ -23,7 +23,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/device.h"
|
#include "tensorflow/core/common_runtime/device.h"
|
||||||
#include "tensorflow/core/framework/bounds_check.h"
|
#include "tensorflow/core/framework/bounds_check.h"
|
||||||
#include "tensorflow/core/framework/kernel_shape_util.h"
|
#include "tensorflow/core/framework/kernel_shape_util.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
using mkldnn::prop_kind;
|
using mkldnn::prop_kind;
|
||||||
|
|
||||||
@ -38,11 +37,11 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
|
|||||||
context_.alg_kind = fwdParams.alg_kind;
|
context_.alg_kind = fwdParams.alg_kind;
|
||||||
context_.prop_kind = fwdParams.prop_kind;
|
context_.prop_kind = fwdParams.prop_kind;
|
||||||
|
|
||||||
// Create memory descriptor
|
// Create memory descriptor
|
||||||
// FIXME: Pooling doesn't expose to get the src_primitive_desc,
|
// FIXME: Pooling doesn't expose to get the src_primitive_desc,
|
||||||
// so src format is currently hard-coded.
|
// so src format is currently hard-coded.
|
||||||
// A utility function is used to do this,
|
// A utility function is used to do this,
|
||||||
// which may be broken with future CPU architectures
|
// which may be broken with future CPU architectures
|
||||||
#ifndef ENABLE_MKLDNN_V1
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
bool is_2d = (fwdParams.src_dims.size() == 4);
|
bool is_2d = (fwdParams.src_dims.size() == 4);
|
||||||
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value)
|
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value)
|
||||||
@ -126,7 +125,8 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
|
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
|
||||||
void* ws_data) {
|
void* ws_data,
|
||||||
|
std::shared_ptr<stream> fwd_stream) {
|
||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(src_data)));
|
static_cast<void*>(const_cast<T*>(src_data)));
|
||||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
||||||
@ -138,10 +138,9 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
|
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
|
||||||
context_.net_args);
|
|
||||||
#else
|
#else
|
||||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
fwd_stream->submit(context_.fwd_primitives);
|
||||||
#endif // ENABLE_MKLDNN_V1
|
#endif // ENABLE_MKLDNN_V1
|
||||||
|
|
||||||
// Set back data handle.
|
// Set back data handle.
|
||||||
@ -268,7 +267,8 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
|
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
|
||||||
T* diff_src_data, const void* ws_data) {
|
T* diff_src_data, const void* ws_data,
|
||||||
|
std::shared_ptr<stream> bwd_stream) {
|
||||||
context_.diff_dst_mem->set_data_handle(
|
context_.diff_dst_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
||||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
||||||
@ -278,10 +278,9 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
execute_primitives(context_.bwd_primitives, context_.bwd_stream,
|
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
|
||||||
context_.net_args);
|
|
||||||
#else
|
#else
|
||||||
context_.bwd_stream->submit(context_.bwd_primitives);
|
bwd_stream->submit(context_.bwd_primitives);
|
||||||
#endif // ENABLE_MKLDNN_V1
|
#endif // ENABLE_MKLDNN_V1
|
||||||
|
|
||||||
// Set back data handle.
|
// Set back data handle.
|
||||||
|
|||||||
@ -86,8 +86,7 @@ template <typename T>
|
|||||||
class MklPoolingFwdPrimitive : public MklPrimitive {
|
class MklPoolingFwdPrimitive : public MklPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
|
explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
|
||||||
: cpu_engine_(ENGINE_CPU, 0) {
|
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||||
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
|
|
||||||
if (context_.fwd == nullptr) Setup(fwdParams);
|
if (context_.fwd == nullptr) Setup(fwdParams);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,7 +96,8 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
|
|||||||
// src_data: input data buffer of src
|
// src_data: input data buffer of src
|
||||||
// ws_data: output data buffer of workspace
|
// ws_data: output data buffer of workspace
|
||||||
// dst_data: output data buffer of dst
|
// dst_data: output data buffer of dst
|
||||||
void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr);
|
void Execute(const T* src_data, T* dst_data, void* ws_data,
|
||||||
|
std::shared_ptr<stream> fwd_stream);
|
||||||
|
|
||||||
std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
|
std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
|
||||||
return context_.fwd_pd;
|
return context_.fwd_pd;
|
||||||
@ -159,12 +159,10 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
|
|||||||
fwd_pd(nullptr),
|
fwd_pd(nullptr),
|
||||||
src_md(nullptr),
|
src_md(nullptr),
|
||||||
dst_md(nullptr),
|
dst_md(nullptr),
|
||||||
fwd(nullptr),
|
fwd(nullptr) {}
|
||||||
fwd_stream(nullptr) {}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PoolingFwdContext context_;
|
struct PoolingFwdContext context_;
|
||||||
engine cpu_engine_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -229,8 +227,7 @@ template <typename T>
|
|||||||
class MklPoolingBwdPrimitive : public MklPrimitive {
|
class MklPoolingBwdPrimitive : public MklPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
|
explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
|
||||||
: cpu_engine_(ENGINE_CPU, 0) {
|
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||||
context_.bwd_stream.reset(new CPU_STREAM(cpu_engine_));
|
|
||||||
if (context_.bwd == nullptr) Setup(bwdParams);
|
if (context_.bwd == nullptr) Setup(bwdParams);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -240,8 +237,8 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
|
|||||||
// diff_dst_data: input data buffer of diff_dst
|
// diff_dst_data: input data buffer of diff_dst
|
||||||
// diff_src_data: output data buffer of diff_src
|
// diff_src_data: output data buffer of diff_src
|
||||||
// ws_data: input data buffer of workspace
|
// ws_data: input data buffer of workspace
|
||||||
void Execute(const T* diff_dst_data, T* diff_src_data,
|
void Execute(const T* diff_dst_data, T* diff_src_data, const void* ws_data,
|
||||||
const void* ws_data = nullptr);
|
std::shared_ptr<stream> bwd_stream);
|
||||||
|
|
||||||
public:
|
public:
|
||||||
std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
|
std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
|
||||||
@ -315,12 +312,10 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
|
|||||||
bwd_desc(nullptr),
|
bwd_desc(nullptr),
|
||||||
fwd_pd(nullptr),
|
fwd_pd(nullptr),
|
||||||
bwd_pd(nullptr),
|
bwd_pd(nullptr),
|
||||||
bwd(nullptr),
|
bwd(nullptr) {}
|
||||||
bwd_stream(nullptr) {}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct PoolingBwdContext context_;
|
struct PoolingBwdContext context_;
|
||||||
engine cpu_engine_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user