Merge pull request #39518 from Intel-tensorflow:sriniva2/threadpool_conv_bwd

PiperOrigin-RevId: 313445190
Change-Id: Ie0d353b3bb8162d84564b62c2757b51e23e0cb6e
This commit is contained in:
TensorFlower Gardener 2020-05-27 12:55:35 -07:00
commit 23971655a4
2 changed files with 45 additions and 36 deletions

View File

@ -97,9 +97,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
public:
explicit MklConvBwdFilterPrimitive(
const MklConvBwdFilterParams& convBwdFilterDims)
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_filter_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
// Create convolution backward filter primitive.
if (context_.conv_bwd_filter == nullptr) {
Setup(convBwdFilterDims);
@ -114,7 +112,8 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
// diff_bias_data: output data buffer for diff_bias
// diff_dst_data: input data buffer for diff_dst
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_bias_data, const T* diff_dst_data) {
const T* diff_bias_data, const T* diff_dst_data,
std::shared_ptr<stream> bwd_filter_stream) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_filter_mem->set_data_handle(
@ -127,11 +126,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
static_cast<void*>(const_cast<T*>(diff_dst_data)));
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_filter_primitives,
context_.bwd_filter_stream,
execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream,
context_.bwd_filter_primitives_args);
#else
context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
bwd_filter_stream->submit(context_.bwd_filter_primitives);
#endif
context_.src_mem->set_data_handle(DummyData);
@ -147,8 +145,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
// diff_filter_data: output data buffer of diff_filter
// diff_dst_data: input data buffer of diff_dst
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_dst_data) {
Execute(src_data, diff_filter_data, nullptr, diff_dst_data);
const T* diff_dst_data,
std::shared_ptr<stream> bwd_filter_stream) {
Execute(src_data, diff_filter_data, nullptr, diff_dst_data,
bwd_filter_stream);
}
#ifndef ENABLE_MKLDNN_V1
@ -223,8 +223,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
src_md(nullptr),
diff_filter_md(nullptr),
diff_bias_md(nullptr),
diff_dst_md(nullptr),
bwd_filter_stream(nullptr) {
diff_dst_md(nullptr) {
}
};
@ -345,7 +344,6 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
}
struct ConvBwdFilterContext context_;
engine cpu_engine_;
};
template <typename T>
@ -600,8 +598,10 @@ class MklConvCustomBackpropFilterOp
auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
if (IS_SRC_REORDER_NEEDED(fwd_src_md, bwd_filter_pd, conv_bwd_filter)) {
src.SetUsrMem(fwd_src_md, &src_tensor);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_filter_pd->PRIMITIVE_DESC_SRC, cpu_engine_));
src.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(bwd_filter_pd->PRIMITIVE_DESC_SRC,
cpu_engine_),
context);
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
@ -612,8 +612,10 @@ class MklConvCustomBackpropFilterOp
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_filter_pd,
conv_bwd_filter)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
diff_dst.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST,
cpu_engine_),
context);
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
@ -646,18 +648,21 @@ class MklConvCustomBackpropFilterOp
}
// Execute convolution backward filter.
std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, conv_bwd_filter->GetEngine()));
if (bias_enabled) {
T* diff_bias_data =
static_cast<T*>(const_cast<T*>(diff_bias_tensor->flat<T>().data()));
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data,
diff_dst_data);
diff_dst_data, bwd_cpu_stream);
} else {
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data,
bwd_cpu_stream);
}
// Reorder diff_filter back to Tensorflow layout if necessary.
if (diff_filter_reorder_required) {
diff_filter.InsertReorderToUserMem();
diff_filter.InsertReorderToUserMem(context);
}
// Delete primitive since it is not cached.

View File

@ -99,9 +99,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
public:
explicit MklConvBwdInputPrimitive(
const MklConvBwdInputParams& convBwdInputDims)
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_input_stream.reset(new CPU_STREAM(cpu_engine_));
: MklPrimitive(engine(ENGINE_CPU, 0)) {
// Create conv bwd input primitive
if (context_.conv_bwd_input == nullptr) {
Setup(convBwdInputDims);
@ -116,7 +114,8 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
// diff_dst_data: input data buffer for dst
// Bias does not matter here
void Execute(const T* diff_src_data, const T* filter_data,
const T* diff_dst_data) {
const T* diff_dst_data,
std::shared_ptr<stream> bwd_input_stream) {
context_.diff_src_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_src_data)));
context_.filter_mem->set_data_handle(
@ -125,10 +124,10 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
static_cast<T*>(const_cast<T*>(diff_dst_data)));
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_input_primitives, context_.bwd_input_stream,
execute_primitives(context_.bwd_input_primitives, bwd_input_stream,
context_.bwd_input_primitives_args);
#else
context_.bwd_input_stream->submit(context_.bwd_input_primitives);
bwd_input_stream->submit(context_.bwd_input_primitives);
#endif // ENABLE_MKLDNN_V1
// Set data handle back to DummyData.
@ -180,7 +179,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
std::shared_ptr<memory::desc> diff_dst_md;
// MKL-DNN pipeline for executing primitives.
std::shared_ptr<mkldnn::stream> bwd_input_stream;
std::vector<mkldnn::primitive> bwd_input_primitives;
#ifdef ENABLE_MKLDNN_V1
@ -203,8 +201,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
fwd_pd(nullptr),
diff_src_md(nullptr),
filter_md(nullptr),
diff_dst_md(nullptr),
bwd_input_stream(nullptr) {
diff_dst_md(nullptr) {
}
};
@ -290,7 +287,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
}
struct ConvBwdInputContext context_;
engine cpu_engine_;
};
template <typename T>
@ -522,8 +518,10 @@ class MklConvCustomBackpropInputOp
if (IS_FILTER_REORDER_NEEDED(fwd_filter_md, bwd_input_pd,
conv_bwd_input)) {
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
filter.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS,
cpu_engine_),
context);
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
} else {
filter_data =
@ -535,23 +533,29 @@ class MklConvCustomBackpropInputOp
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_input_pd,
conv_bwd_input)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
diff_dst.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST,
cpu_engine_),
context);
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
}
std::shared_ptr<stream> bwd_cpu_stream;
bwd_cpu_stream.reset(CreateStream(context, conv_bwd_input->GetEngine()));
// Execute conv bwd input primitive.
if (!eager_mode) {
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data,
bwd_cpu_stream);
} else {
// In eager mode we first write the output to temporary
// buffer in MKL format. Then we convert the data to TF format.
T* tmp_data =
static_cast<T*>(const_cast<T*>(tmp_tensor.flat<T>().data()));
conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data);
conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data,
bwd_cpu_stream);
auto output_tf_md = diff_src_mkl_shape.GetTfLayout();
#ifndef ENABLE_MKLDNN_V1
auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_);
@ -563,7 +567,7 @@ class MklConvCustomBackpropInputOp
memory* dst_data_mem =
new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, diff_src_data);
CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
cpu_engine_);
cpu_engine_, context);
}
// Delete primitive since it is not cached.