Merge pull request #39518 from Intel-tensorflow:sriniva2/threadpool_conv_bwd
PiperOrigin-RevId: 313445190 Change-Id: Ie0d353b3bb8162d84564b62c2757b51e23e0cb6e
This commit is contained in:
commit
23971655a4
@ -97,9 +97,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklConvBwdFilterPrimitive(
|
||||
const MklConvBwdFilterParams& convBwdFilterDims)
|
||||
: cpu_engine_(ENGINE_CPU, 0) {
|
||||
context_.bwd_filter_stream.reset(new CPU_STREAM(cpu_engine_));
|
||||
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
// Create convolution backward filter primitive.
|
||||
if (context_.conv_bwd_filter == nullptr) {
|
||||
Setup(convBwdFilterDims);
|
||||
@ -114,7 +112,8 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
// diff_bias_data: output data buffer for diff_bias
|
||||
// diff_dst_data: input data buffer for diff_dst
|
||||
void Execute(const T* src_data, const T* diff_filter_data,
|
||||
const T* diff_bias_data, const T* diff_dst_data) {
|
||||
const T* diff_bias_data, const T* diff_dst_data,
|
||||
std::shared_ptr<stream> bwd_filter_stream) {
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.diff_filter_mem->set_data_handle(
|
||||
@ -127,11 +126,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.bwd_filter_primitives,
|
||||
context_.bwd_filter_stream,
|
||||
execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream,
|
||||
context_.bwd_filter_primitives_args);
|
||||
#else
|
||||
context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
|
||||
bwd_filter_stream->submit(context_.bwd_filter_primitives);
|
||||
#endif
|
||||
|
||||
context_.src_mem->set_data_handle(DummyData);
|
||||
@ -147,8 +145,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
// diff_filter_data: output data buffer of diff_filter
|
||||
// diff_dst_data: input data buffer of diff_dst
|
||||
void Execute(const T* src_data, const T* diff_filter_data,
|
||||
const T* diff_dst_data) {
|
||||
Execute(src_data, diff_filter_data, nullptr, diff_dst_data);
|
||||
const T* diff_dst_data,
|
||||
std::shared_ptr<stream> bwd_filter_stream) {
|
||||
Execute(src_data, diff_filter_data, nullptr, diff_dst_data,
|
||||
bwd_filter_stream);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
@ -223,8 +223,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
src_md(nullptr),
|
||||
diff_filter_md(nullptr),
|
||||
diff_bias_md(nullptr),
|
||||
diff_dst_md(nullptr),
|
||||
bwd_filter_stream(nullptr) {
|
||||
diff_dst_md(nullptr) {
|
||||
}
|
||||
};
|
||||
|
||||
@ -345,7 +344,6 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
}
|
||||
|
||||
struct ConvBwdFilterContext context_;
|
||||
engine cpu_engine_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -600,8 +598,10 @@ class MklConvCustomBackpropFilterOp
|
||||
auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
|
||||
if (IS_SRC_REORDER_NEEDED(fwd_src_md, bwd_filter_pd, conv_bwd_filter)) {
|
||||
src.SetUsrMem(fwd_src_md, &src_tensor);
|
||||
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
bwd_filter_pd->PRIMITIVE_DESC_SRC, cpu_engine_));
|
||||
src.CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(bwd_filter_pd->PRIMITIVE_DESC_SRC,
|
||||
cpu_engine_),
|
||||
context);
|
||||
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
|
||||
} else {
|
||||
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
|
||||
@ -612,8 +612,10 @@ class MklConvCustomBackpropFilterOp
|
||||
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_filter_pd,
|
||||
conv_bwd_filter)) {
|
||||
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
|
||||
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
|
||||
diff_dst.CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST,
|
||||
cpu_engine_),
|
||||
context);
|
||||
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
|
||||
} else {
|
||||
diff_dst_data =
|
||||
@ -646,18 +648,21 @@ class MklConvCustomBackpropFilterOp
|
||||
}
|
||||
|
||||
// Execute convolution backward filter.
|
||||
std::shared_ptr<stream> bwd_cpu_stream;
|
||||
bwd_cpu_stream.reset(CreateStream(context, conv_bwd_filter->GetEngine()));
|
||||
if (bias_enabled) {
|
||||
T* diff_bias_data =
|
||||
static_cast<T*>(const_cast<T*>(diff_bias_tensor->flat<T>().data()));
|
||||
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data,
|
||||
diff_dst_data);
|
||||
diff_dst_data, bwd_cpu_stream);
|
||||
} else {
|
||||
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
|
||||
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data,
|
||||
bwd_cpu_stream);
|
||||
}
|
||||
|
||||
// Reorder diff_filter back to Tensorflow layout if necessary.
|
||||
if (diff_filter_reorder_required) {
|
||||
diff_filter.InsertReorderToUserMem();
|
||||
diff_filter.InsertReorderToUserMem(context);
|
||||
}
|
||||
|
||||
// Delete primitive since it is not cached.
|
||||
|
@ -99,9 +99,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklConvBwdInputPrimitive(
|
||||
const MklConvBwdInputParams& convBwdInputDims)
|
||||
: cpu_engine_(ENGINE_CPU, 0) {
|
||||
context_.bwd_input_stream.reset(new CPU_STREAM(cpu_engine_));
|
||||
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
// Create conv bwd input primitive
|
||||
if (context_.conv_bwd_input == nullptr) {
|
||||
Setup(convBwdInputDims);
|
||||
@ -116,7 +114,8 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
|
||||
// diff_dst_data: input data buffer for dst
|
||||
// Bias does not matter here
|
||||
void Execute(const T* diff_src_data, const T* filter_data,
|
||||
const T* diff_dst_data) {
|
||||
const T* diff_dst_data,
|
||||
std::shared_ptr<stream> bwd_input_stream) {
|
||||
context_.diff_src_mem->set_data_handle(
|
||||
static_cast<T*>(const_cast<T*>(diff_src_data)));
|
||||
context_.filter_mem->set_data_handle(
|
||||
@ -125,10 +124,10 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
|
||||
static_cast<T*>(const_cast<T*>(diff_dst_data)));
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.bwd_input_primitives, context_.bwd_input_stream,
|
||||
execute_primitives(context_.bwd_input_primitives, bwd_input_stream,
|
||||
context_.bwd_input_primitives_args);
|
||||
#else
|
||||
context_.bwd_input_stream->submit(context_.bwd_input_primitives);
|
||||
bwd_input_stream->submit(context_.bwd_input_primitives);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
// Set data handle back to DummyData.
|
||||
@ -180,7 +179,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
|
||||
std::shared_ptr<memory::desc> diff_dst_md;
|
||||
|
||||
// MKL-DNN pipeline for executing primitives.
|
||||
std::shared_ptr<mkldnn::stream> bwd_input_stream;
|
||||
std::vector<mkldnn::primitive> bwd_input_primitives;
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
@ -203,8 +201,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
|
||||
fwd_pd(nullptr),
|
||||
diff_src_md(nullptr),
|
||||
filter_md(nullptr),
|
||||
diff_dst_md(nullptr),
|
||||
bwd_input_stream(nullptr) {
|
||||
diff_dst_md(nullptr) {
|
||||
}
|
||||
};
|
||||
|
||||
@ -290,7 +287,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
|
||||
}
|
||||
|
||||
struct ConvBwdInputContext context_;
|
||||
engine cpu_engine_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -522,8 +518,10 @@ class MklConvCustomBackpropInputOp
|
||||
if (IS_FILTER_REORDER_NEEDED(fwd_filter_md, bwd_input_pd,
|
||||
conv_bwd_input)) {
|
||||
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
|
||||
filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
|
||||
filter.CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS,
|
||||
cpu_engine_),
|
||||
context);
|
||||
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
|
||||
} else {
|
||||
filter_data =
|
||||
@ -535,23 +533,29 @@ class MklConvCustomBackpropInputOp
|
||||
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_input_pd,
|
||||
conv_bwd_input)) {
|
||||
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
|
||||
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
|
||||
diff_dst.CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST,
|
||||
cpu_engine_),
|
||||
context);
|
||||
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
|
||||
} else {
|
||||
diff_dst_data =
|
||||
static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
|
||||
}
|
||||
|
||||
std::shared_ptr<stream> bwd_cpu_stream;
|
||||
bwd_cpu_stream.reset(CreateStream(context, conv_bwd_input->GetEngine()));
|
||||
// Execute conv bwd input primitive.
|
||||
if (!eager_mode) {
|
||||
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
|
||||
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data,
|
||||
bwd_cpu_stream);
|
||||
} else {
|
||||
// In eager mode we first write the output to temporary
|
||||
// buffer in MKL format. Then we convert the data to TF format.
|
||||
T* tmp_data =
|
||||
static_cast<T*>(const_cast<T*>(tmp_tensor.flat<T>().data()));
|
||||
conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data);
|
||||
conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data,
|
||||
bwd_cpu_stream);
|
||||
auto output_tf_md = diff_src_mkl_shape.GetTfLayout();
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_);
|
||||
@ -563,7 +567,7 @@ class MklConvCustomBackpropInputOp
|
||||
memory* dst_data_mem =
|
||||
new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, diff_src_data);
|
||||
CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
|
||||
cpu_engine_);
|
||||
cpu_engine_, context);
|
||||
}
|
||||
|
||||
// Delete primitive since it is not cached.
|
||||
|
Loading…
Reference in New Issue
Block a user