Merge pull request #39519 from Intel-tensorflow:sriniva2/threadpool_pooling

PiperOrigin-RevId: 313437523
Change-Id: I9ebb625cc949eef464de2fcbb0ce77635e7c41e8
This commit is contained in:
TensorFlower Gardener 2020-05-27 12:16:52 -07:00
commit 54e57d69d2
4 changed files with 42 additions and 37 deletions

View File

@ -136,9 +136,10 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
const T* src_data = input_tensor.flat<T>().data();
T* dst_data = output_tensor->flat<T>().data();
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));
// Execute pooling op.
pooling_fwd->Execute(src_data, dst_data);
pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream);
// Pass min, max from input to output.
if (int8_forward_inference) {
@ -240,8 +241,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
: memory::desc(diff_dst_dims, MklDnnType<T>(),
this->data_format_mkldnn_);
// Pass prop_kind::forward_training to create a forward primitive
// that is used in the backward pass.
// Pass prop_kind::forward_training to create a forward primitive
// that is used in the backward pass.
#ifdef ENABLE_MKLDNN_V1
// TODO(DNNL): Find out what we should use for src_md.data.format.
MklPoolingParams bwdParams(
@ -260,6 +261,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
MklPoolingBwdPrimitive<T>* pooling_bwd =
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
Tensor* output_tensor = nullptr;
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
orig_input_dims_mkl_order,
@ -286,7 +289,8 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
T* diff_src_data = output_tensor->flat<T>().data();
// Execute pooling op.
pooling_bwd->Execute(diff_dst_data, diff_src_data);
pooling_bwd->Execute(diff_dst_data, diff_src_data, nullptr,
bwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +

View File

@ -167,10 +167,12 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
const T* src_data = input_tensor.flat<T>().data();
T* dst_data = output_tensor->flat<T>().data();
std::shared_ptr<stream> fwd_cpu_stream;
fwd_cpu_stream.reset(CreateStream(context, pooling_fwd->GetEngine()));
if (int8_forward_inference) {
// Execute pooling op
pooling_fwd->Execute(src_data, dst_data);
pooling_fwd->Execute(src_data, dst_data, nullptr, fwd_cpu_stream);
// Pass min, max from input to output.
const Tensor& min_input_t = MklGetInput(context, 1);
@ -197,7 +199,7 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
T* ws_data =
static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle());
// Execute pooling op.
pooling_fwd->Execute(src_data, dst_data, ws_data);
pooling_fwd->Execute(src_data, dst_data, ws_data, fwd_cpu_stream);
}
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
@ -322,6 +324,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
MklPoolingBwdPrimitive<T>* pooling_bwd =
MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, pooling_bwd->GetEngine()));
// Allocate output tensor and memory primitive.
Tensor* output_tensor = nullptr;
this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
@ -335,8 +339,10 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, pooling_bwd_pd,
pooling_bwd)) {
grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
grad_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd), cpu_engine_));
grad_dnn_data.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_DIFF_DST_DESC_FROM_OP_PD(pooling_bwd_pd),
cpu_engine_),
context);
diff_dst_data =
static_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle());
} else {
@ -361,7 +367,8 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
T* diff_src_data = output_tensor->flat<T>().data();
// Execute pooling op.
pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);
pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data,
bwd_cpu_stream);
} catch (mkldnn::error& e) {
string error_msg = "Status:" + std::to_string(e.status) +
", message: " + string(e.message) + ". in file " +

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
namespace tensorflow {
using mkldnn::prop_kind;
@ -38,11 +37,11 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
context_.alg_kind = fwdParams.alg_kind;
context_.prop_kind = fwdParams.prop_kind;
// Create memory descriptor
// FIXME: Pooling does not expose a way to get the src_primitive_desc,
// so the src format is currently hard-coded.
// A utility function is used to do this,
// which may break with future CPU architectures.
// Create memory descriptor
// FIXME: Pooling does not expose a way to get the src_primitive_desc,
// so the src format is currently hard-coded.
// A utility function is used to do this,
// which may break with future CPU architectures.
#ifndef ENABLE_MKLDNN_V1
bool is_2d = (fwdParams.src_dims.size() == 4);
if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value)
@ -126,7 +125,8 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
template <typename T>
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
void* ws_data) {
void* ws_data,
std::shared_ptr<stream> fwd_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@ -138,10 +138,9 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
}
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.fwd_primitives, context_.fwd_stream,
context_.net_args);
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
#else
context_.fwd_stream->submit(context_.fwd_primitives);
fwd_stream->submit(context_.fwd_primitives);
#endif // ENABLE_MKLDNN_V1
// Set back data handle.
@ -268,7 +267,8 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
template <typename T>
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
T* diff_src_data, const void* ws_data) {
T* diff_src_data, const void* ws_data,
std::shared_ptr<stream> bwd_stream) {
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
@ -278,10 +278,9 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
}
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_primitives, context_.bwd_stream,
context_.net_args);
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
#else
context_.bwd_stream->submit(context_.bwd_primitives);
bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1
// Set back data handle.

View File

@ -86,8 +86,7 @@ template <typename T>
class MklPoolingFwdPrimitive : public MklPrimitive {
public:
explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
if (context_.fwd == nullptr) Setup(fwdParams);
}
@ -97,7 +96,8 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
// src_data: input data buffer of src
// ws_data: output data buffer of workspace
// dst_data: output data buffer of dst
void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr);
void Execute(const T* src_data, T* dst_data, void* ws_data,
std::shared_ptr<stream> fwd_stream);
std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
return context_.fwd_pd;
@ -159,12 +159,10 @@ class MklPoolingFwdPrimitive : public MklPrimitive {
fwd_pd(nullptr),
src_md(nullptr),
dst_md(nullptr),
fwd(nullptr),
fwd_stream(nullptr) {}
fwd(nullptr) {}
};
struct PoolingFwdContext context_;
engine cpu_engine_;
};
template <typename T>
@ -229,8 +227,7 @@ template <typename T>
class MklPoolingBwdPrimitive : public MklPrimitive {
public:
explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
if (context_.bwd == nullptr) Setup(bwdParams);
}
@ -240,8 +237,8 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
// diff_dst_data: input data buffer of diff_dst
// diff_src_data: output data buffer of diff_src
// ws_data: input data buffer of workspace
void Execute(const T* diff_dst_data, T* diff_src_data,
const void* ws_data = nullptr);
void Execute(const T* diff_dst_data, T* diff_src_data, const void* ws_data,
std::shared_ptr<stream> bwd_stream);
public:
std::shared_ptr<PoolingFwdPd> GetPoolingFwdPd() const {
@ -315,12 +312,10 @@ class MklPoolingBwdPrimitive : public MklPrimitive {
bwd_desc(nullptr),
fwd_pd(nullptr),
bwd_pd(nullptr),
bwd(nullptr),
bwd_stream(nullptr) {}
bwd(nullptr) {}
};
struct PoolingBwdContext context_;
engine cpu_engine_;
};
template <typename T>