Merge pull request #36630 from Intel-tensorflow:dnn10_conv_bwd

PiperOrigin-RevId: 295228960
Change-Id: I2bc99b61cd0a8dc9b9771a0e65e393fd6176d5d9
This commit is contained in:
TensorFlower Gardener 2020-02-14 14:40:09 -08:00
commit 81323b7924
3 changed files with 392 additions and 306 deletions

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@ -46,8 +47,12 @@ using mkldnn::prop_kind;
using mkldnn::stream;
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
using ConvBwdFilterDesc = mkldnn::convolution_backward_weights::desc;
using ConvBwdFilterPd = mkldnn::convolution_backward_weights::primitive_desc;
struct MklConvBwdFilterParams {
memory::dims src_dims;
memory::dims diff_filter_dims;
@ -80,9 +85,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
public:
explicit MklConvBwdFilterPrimitive(
const MklConvBwdFilterParams& convBwdFilterDims)
: cpu_engine_(engine::cpu, 0) {
context_.bwd_filter_stream.reset(new stream(stream::kind::eager));
// create conv primitive
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_filter_stream.reset(new CPU_STREAM(cpu_engine_));
// Create convolution backward filter primitive.
if (context_.conv_bwd_filter == nullptr) {
Setup(convBwdFilterDims);
}
@ -90,106 +96,111 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
~MklConvBwdFilterPrimitive() {}
// Convolution backward weights with bias
// src_data: input data buffer of src
// diff_filter_data: output data buffer of diff_filter
// diff_bias_data: output data buffer of diff_bias
// diff_dst_data: input data buffer of diff_dst
// Convolution backward weights execution with bias
// src_data: input data buffer for src
// diff_filter_data: output data buffer for diff_filter
// diff_bias_data: output data buffer for diff_bias
// diff_dst_data: input data buffer for diff_dst
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_bias_data, const T* diff_dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_filter_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_filter_data)));
context_.diff_bias_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_bias_data)));
if (diff_bias_data != nullptr) {
context_.diff_bias_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_bias_data)));
}
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_filter_primitives,
context_.bwd_filter_stream,
context_.bwd_filter_primitives_args);
#else
context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
#endif
context_.src_mem->set_data_handle(DummyData);
context_.diff_filter_mem->set_data_handle(DummyData);
context_.diff_bias_mem->set_data_handle(DummyData);
if (diff_bias_data != nullptr) {
context_.diff_bias_mem->set_data_handle(DummyData);
}
context_.diff_dst_mem->set_data_handle(DummyData);
return;
}
// Convolution backward weights without bias
// Convolution backward weights without bias.
// src_data: input data buffer of src
// diff_filter_data: output data buffer of diff_filter
// diff_dst_data: input data buffer of diff_dst
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_filter_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_filter_data)));
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
context_.src_mem->set_data_handle(DummyData);
context_.diff_filter_mem->set_data_handle(DummyData);
context_.diff_dst_mem->set_data_handle(DummyData);
return;
Execute(src_data, diff_filter_data, nullptr, diff_dst_data);
}
#ifndef ENABLE_MKLDNN_V1
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
memory::format GetDiffDstMemoryFormat() const {
return context_.diff_dst_fmt;
}
memory::format GetDiffFilterMemoryFormat() const {
return context_.diff_filter_fmt;
}
#endif
// convolution primitive
std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
GetPrimitiveDesc() const {
std::shared_ptr<ConvBwdFilterPd> GetPrimitiveDesc() const {
return context_.bwd_filter_pd;
}
private:
// Primitive reuse context for Conv2D bwd filter op
// Primitive reuse context for Conv2D backward filter op.
struct ConvBwdFilterContext {
// expected memory format for this primitive instance
#ifndef ENABLE_MKLDNN_V1
// Expected memory format for this primitive instance
memory::format src_fmt;
memory::format diff_dst_fmt;
memory::format diff_filter_fmt;
#endif // !ENABLE_MKLDNN_V1
// convolution bwd input primitive
std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
bwd_filter_pd;
std::shared_ptr<mkldnn::primitive> conv_bwd_filter;
// MKLDNN memory
// MKL-DNN memory for inputs and outputs.
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> diff_filter_mem;
std::shared_ptr<mkldnn::memory> diff_bias_mem;
std::shared_ptr<mkldnn::memory> diff_dst_mem;
// desc & prmitive desc
std::shared_ptr<mkldnn::convolution_backward_weights::desc> bwd_filter_desc;
std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;
// Primitive descriptor and descriptor for convolution backward filter.
std::shared_ptr<ConvBwdFilterPd> bwd_filter_pd;
std::shared_ptr<ConvBwdFilterDesc> bwd_filter_desc;
// memory desc: forward & backward can share same memory desc
// Primitive descriptor and descriptor for convolution forward.
std::shared_ptr<ConvFwdPd> fwd_pd;
std::shared_ptr<ConvFwdDesc> fwd_desc;
// Convolution backward filter primitive.
std::shared_ptr<mkldnn::primitive> conv_bwd_filter;
// Memory descriptors: forward & backward share the same memory descriptors
std::shared_ptr<mkldnn::memory::desc> src_md;
std::shared_ptr<mkldnn::memory::desc> diff_filter_md;
std::shared_ptr<mkldnn::memory::desc> diff_bias_md;
std::shared_ptr<mkldnn::memory::desc> diff_dst_md;
// MKL pipeline
// MKL-DNN pipeline for executing primitives.
std::shared_ptr<mkldnn::stream> bwd_filter_stream;
std::vector<mkldnn::primitive> bwd_filter_primitives;
#ifdef ENABLE_MKLDNN_V1
std::vector<MemoryArgsMap> bwd_filter_primitives_args;
#endif
ConvBwdFilterContext()
: src_fmt(memory::format::any),
:
#ifndef ENABLE_MKLDNN_V1
src_fmt(memory::format::any),
diff_dst_fmt(memory::format::any),
diff_filter_fmt(memory::format::any),
#endif
src_mem(nullptr),
diff_filter_mem(nullptr),
diff_bias_mem(nullptr),
@ -201,84 +212,102 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
diff_filter_md(nullptr),
diff_bias_md(nullptr),
diff_dst_md(nullptr),
bwd_filter_stream(nullptr) {}
bwd_filter_stream(nullptr) {
}
};
// Setup Conv2d backward filter (weights) primitives.
void Setup(const MklConvBwdFilterParams& convBwdFilterDims) {
// create memory descriptors for convolution data w/ no specified format
// Create memory descriptors for convolution backward filter without any
// specific format so that MKL-DNN can pick an appropriate one depending
// on the input parameters.
context_.src_md.reset(new memory::desc(
{convBwdFilterDims.src_dims}, MklDnnType<T>(), memory::format::any));
{convBwdFilterDims.src_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
context_.diff_dst_md.reset(
new memory::desc({convBwdFilterDims.diff_dst_dims}, MklDnnType<T>(),
memory::format::any));
MEMORY_FORMAT::any));
context_.diff_filter_md.reset(
new memory::desc({convBwdFilterDims.diff_filter_dims}, MklDnnType<T>(),
memory::format::any));
MEMORY_FORMAT::any));
if (!convBwdFilterDims.diff_bias_dims.empty())
context_.diff_bias_md.reset(
new memory::desc({convBwdFilterDims.diff_bias_dims}, MklDnnType<T>(),
memory::format::x));
MEMORY_FORMAT::x));
// create a convolution
if (!convBwdFilterDims.diff_bias_dims.empty()) {
context_.bwd_filter_desc.reset(new convolution_backward_weights::desc(
convolution_direct, *context_.src_md, *context_.diff_filter_md,
*context_.diff_bias_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
convBwdFilterDims.padding));
} else {
context_.bwd_filter_desc.reset(new convolution_backward_weights::desc(
convolution_direct, *context_.src_md, *context_.diff_filter_md,
*context_.diff_dst_md, convBwdFilterDims.strides,
convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
convBwdFilterDims.padding_right, convBwdFilterDims.padding));
}
// create fwd primitive_desc
context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, convolution_direct, *context_.src_md,
// Create descriptor and primitive descriptor for convolution forward.
context_.fwd_desc.reset(new ConvFwdDesc(
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
*context_.diff_filter_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
convBwdFilterDims.padding));
context_.fwd_pd.reset(new convolution_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
// create backward conv primitive_desc
context_.bwd_filter_pd.reset(
new convolution_backward_weights::primitive_desc(
*context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));
// Create descriptor and primitive descriptor for convolution bwd filter.
if (!convBwdFilterDims.diff_bias_dims.empty()) {
context_.bwd_filter_desc.reset(new ConvBwdFilterDesc(
ALGORITHM::convolution_direct, *context_.src_md,
*context_.diff_filter_md, *context_.diff_bias_md,
*context_.diff_dst_md, convBwdFilterDims.strides,
convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
convBwdFilterDims.padding_right, convBwdFilterDims.padding));
} else {
context_.bwd_filter_desc.reset(new ConvBwdFilterDesc(
ALGORITHM::convolution_direct, *context_.src_md,
*context_.diff_filter_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
convBwdFilterDims.padding));
}
context_.bwd_filter_pd.reset(new ConvBwdFilterPd(
*context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));
// store the expected memory format
auto bwd_filter_pd = context_.bwd_filter_pd.get();
#ifndef ENABLE_MKLDNN_V1
// Store the expected memory format.
context_.src_fmt = static_cast<mkldnn::memory::format>(
bwd_filter_pd->src_primitive_desc().desc().data.format);
context_.diff_filter_fmt = static_cast<mkldnn::memory::format>(
bwd_filter_pd->diff_weights_primitive_desc().desc().data.format);
context_.diff_dst_fmt = static_cast<mkldnn::memory::format>(
bwd_filter_pd->diff_dst_primitive_desc().desc().data.format);
#endif // !ENABLE_MKLDNN_V1
// create memory primitive based on dummy data
context_.src_mem.reset(
new memory(bwd_filter_pd->src_primitive_desc(), DummyData));
context_.diff_filter_mem.reset(
new memory(bwd_filter_pd->diff_weights_primitive_desc(), DummyData));
context_.diff_dst_mem.reset(
new memory(bwd_filter_pd->diff_dst_primitive_desc(), DummyData));
// Create memory using dummy data.
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_filter_pd->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData));
context_.diff_filter_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_filter_pd->PRIMITIVE_DESC_DIFF_WEIGHTS, cpu_engine_, DummyData));
context_.diff_dst_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_, DummyData));
// create convolution primitive and add it to net
// Create convolution backward filter primitive and add it to the net.
if (!convBwdFilterDims.diff_bias_dims.empty()) {
context_.diff_bias_mem.reset(
new memory({{{convBwdFilterDims.diff_bias_dims},
MklDnnType<T>(),
memory::format::x},
cpu_engine_},
DummyData));
context_.diff_bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
convBwdFilterDims.diff_bias_dims, T, MEMORY_FORMAT::x, cpu_engine_,
DummyData));
#ifdef ENABLE_MKLDNN_V1
context_.conv_bwd_filter.reset(
new convolution_backward_weights(*context_.bwd_filter_pd));
context_.bwd_filter_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_DIFF_WEIGHTS, *context_.diff_filter_mem},
{MKLDNN_ARG_DIFF_BIAS, *context_.diff_bias_mem},
{ MKLDNN_ARG_DIFF_DST,
*context_.diff_dst_mem }});
} else {
context_.conv_bwd_filter.reset(
new convolution_backward_weights(*context_.bwd_filter_pd));
context_.bwd_filter_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_DIFF_WEIGHTS, *context_.diff_filter_mem},
{ MKLDNN_ARG_DIFF_DST,
*context_.diff_dst_mem }});
}
#else
context_.conv_bwd_filter.reset(new convolution_backward_weights(
*context_.bwd_filter_pd, *context_.src_mem, *context_.diff_dst_mem,
*context_.diff_filter_mem, *context_.diff_bias_mem));
@ -287,7 +316,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
*context_.bwd_filter_pd, *context_.src_mem, *context_.diff_dst_mem,
*context_.diff_filter_mem));
}
#endif // ENABLE_MKLDNN_V1
context_.bwd_filter_primitives.push_back(*context_.conv_bwd_filter);
}
@ -305,7 +334,7 @@ class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
if (do_not_cache) { /* Create new primitive always */
conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
} else {
// look into the pool for reusable primitive
// Look into the pool for reusable primitive.
conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*>(
MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
convBwdFilterDims));
@ -369,23 +398,15 @@ class MklConvCustomBackpropFilterOp
void Compute(OpKernelContext* context) {
try {
MklDnnData<T> src(&cpu_engine_);
MklDnnData<T> diff_dst(&cpu_engine_);
MklDnnData<T> diff_filter(&cpu_engine_); // output
// This flag indicates Conv2D or Conv3D
bool is_conv2d = (this->strides_.size() == 4);
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
// Input tensors.
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIdx);
MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
GetMklShape(context, kInputIdx, &src_mkl_shape, eager_mode);
GetMklShape(context, kFilterIdx, &filter_mkl_shape, eager_mode);
GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape, eager_mode);
GetMklShape(context, kDiffDstIdx, &diff_dst_mkl_shape, eager_mode);
// Allow operator-specific sanity checking of shapes.
ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);
@ -397,7 +418,7 @@ class MklConvCustomBackpropFilterOp
TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
TensorShape diff_dst_tf_shape =
GetTfShape(context, kOutbpropIdx, eager_mode);
GetTfShape(context, kDiffDstIdx, eager_mode);
// Corner cases: output with 0 elements and 0 batch size.
Tensor* diff_filter_tensor = nullptr;
@ -412,9 +433,9 @@ class MklConvCustomBackpropFilterOp
AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor,
diff_filter_tf_shape, diff_filter_mkl_shape,
eager_mode);
CHECK_NOTNULL(diff_filter_tensor);
DCHECK(diff_filter_tensor != nullptr);
// if output tensor has more than 0 elements, we need to 0 them out.
// If output tensor has more than 0 elements, we need to 0 them out.
auto diff_filter_data = diff_filter_tensor->flat<T>().data();
for (size_t i = 0; i < diff_filter_tf_shape.num_elements(); ++i) {
diff_filter_data[i] = static_cast<T>(0);
@ -422,38 +443,44 @@ class MklConvCustomBackpropFilterOp
return;
}
// By default, all dims are in MKL order. Only dims in TF order
// are those with prefix tf_order.
// By default, all dims are in MKL order except those that are suffixed
// with `tf_order`.
memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
memory::dims padding_left, padding_right, dilations, strides,
fwd_dst_dims;
memory::dims fwd_dst_dims_tf_order;
memory::dims padding_left, padding_right, dilations, strides;
memory::dims fwd_dst_dims, fwd_dst_dims_tf_order;
// Get forward convolution parameters.
MklDnnConvUtil conv_utl(context, this->strides_, this->padding_,
this->data_format_, this->dilations_);
conv_utl.GetConvFwdSizesInMklOrder(
MklDnnConvUtil conv_util(context, this->strides_, this->padding_,
this->data_format_, this->dilations_);
conv_util.GetConvFwdSizesInMklOrder(
src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims,
&strides, &dilations, &fwd_dst_dims_tf_order, &fwd_dst_dims,
&padding_left, &padding_right, false, is_depthwise);
if (!context->status().ok()) return;
bool is_conv2d = (this->strides_.size() == 4);
auto tf_fmt = is_conv2d
? TFDataFormatToMklDnnDataFormat(this->data_format_)
: TFDataFormatToMklDnn3DDataFormat(this->data_format_);
#ifdef ENABLE_MKLDNN_V1
auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
errors::InvalidArgument("Invalid data format"));
#endif
auto fwd_src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(fwd_src_dims, MklDnnType<T>(), tf_fmt);
: memory::desc(fwd_src_dims, MklDnnType<T>(), MKL_FMT_TAG);
conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
conv_util.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
if (!context->status().ok()) return;
auto diff_dst_md =
diff_dst_mkl_shape.IsMklTensor()
? diff_dst_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims, MklDnnType<T>(), tf_fmt);
: memory::desc(diff_dst_dims, MklDnnType<T>(), MKL_FMT_TAG);
memory::dims diff_bias_dims = {};
int64 depth = 0;
@ -464,26 +491,28 @@ class MklConvCustomBackpropFilterOp
: obp_tf_shape.dim_size(is_conv2d ? 3 : 4);
diff_bias_dims = {static_cast<int>(depth)};
}
for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
// The default dilation factor for each dimension is 1 in TF and
// 0 in MKL-DNN.
for (int i = 0; i < dilations.size(); ++i) --dilations[i];
MklConvBwdFilterParams convBwdFilterDims(
fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides,
dilations, padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
// MKL DNN allocates large buffers when a conv gradient filter primtive is
// MKL-DNN allocates large buffers when a conv gradient filter primitive is
// created. So we don't cache conv backward primitives when the env
// variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true.
bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
convBwdFilterDims, do_not_cache);
auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
// allocate output tensors: diff_fitler and diff_bias (w bias)
auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
MklConvBwdFilterPrimitive<T>* conv_bwd_filter =
MklConvBwdFilterPrimitiveFactory<T>::Get(convBwdFilterDims,
do_not_cache);
// Allocate output tensors: diff_filter and diff_bias (w bias).
auto diff_filter_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
// diff_filter
MklDnnShape diff_filter_mkl_shape;
diff_filter_mkl_shape.SetMklTensor(false);
@ -491,15 +520,15 @@ class MklConvCustomBackpropFilterOp
if (!is_depthwise) {
// Conv2D: output_dims_mkl_order is in OIHW format.
TensorShape diff_filter_tf_shape(
{bwd_output_dims[MklDnnDims::Dim_H],
bwd_output_dims[MklDnnDims::Dim_W],
bwd_output_dims[MklDnnDims::Dim_I],
bwd_output_dims[MklDnnDims::Dim_O]});
{diff_filter_dims[MklDnnDims::Dim_H],
diff_filter_dims[MklDnnDims::Dim_W],
diff_filter_dims[MklDnnDims::Dim_I],
diff_filter_dims[MklDnnDims::Dim_O]});
AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
diff_filter_tf_shape, diff_filter_mkl_shape,
eager_mode);
} else {
// Depthwise Conv2d: bwd_output_dims is GOIHW format
// Depthwise Conv2d: diff_filter_dims is GOIHW format.
// | TensorFlow | MKLDNN
// ----------------------------------------------------------------
// filter_out_depth | depth_multiplier | depth_multiplier *
@ -511,10 +540,11 @@ class MklConvCustomBackpropFilterOp
// And the GOIHW is mkldnn format, here we try to extract the TF
// format, TF format is HWIO, as G = original I, so here is HWGO.
TensorShape diff_filter_tf_shape(
{bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_H],
bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_W],
bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_G],
bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_O]});
{diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_H],
diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_W],
diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_G],
diff_filter_dims
[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_O]});
AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
diff_filter_tf_shape,
diff_filter_mkl_shape);
@ -522,11 +552,11 @@ class MklConvCustomBackpropFilterOp
} else {
// Conv3D: output_dims_mkl_order is in OIDHW format.
TensorShape diff_filter_tf_shape(
{bwd_output_dims[MklDnnDims3D::Dim3d_D],
bwd_output_dims[MklDnnDims3D::Dim3d_H],
bwd_output_dims[MklDnnDims3D::Dim3d_W],
bwd_output_dims[MklDnnDims3D::Dim3d_I],
bwd_output_dims[MklDnnDims3D::Dim3d_O]});
{diff_filter_dims[MklDnnDims3D::Dim3d_D],
diff_filter_dims[MklDnnDims3D::Dim3d_H],
diff_filter_dims[MklDnnDims3D::Dim3d_W],
diff_filter_dims[MklDnnDims3D::Dim3d_I],
diff_filter_dims[MklDnnDims3D::Dim3d_O]});
AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
diff_filter_tf_shape, diff_filter_mkl_shape);
}
@ -537,39 +567,50 @@ class MklConvCustomBackpropFilterOp
AllocateBiasGradTensor(context, diff_bias_shape, &diff_bias_tensor);
}
// check if src and diff_dst need reorder
// Check if src and diff_dst need to be reordered.
T* src_data = nullptr;
if (fwd_src_md.data.format != conv_bwd_filter->GetSrcMemoryFormat()) {
MklDnnData<T> src(&cpu_engine_);
auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
if (IS_SRC_REORDER_NEEDED(fwd_src_md, bwd_filter_pd, conv_bwd_filter)) {
src.SetUsrMem(fwd_src_md, &src_tensor);
src.CheckReorderToOpMem(bwd_filter_pd->src_primitive_desc());
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_filter_pd->PRIMITIVE_DESC_SRC, cpu_engine_));
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
}
T* diff_dst_data = nullptr;
if (diff_dst_md.data.format !=
conv_bwd_filter->GetDiffDstMemoryFormat()) {
MklDnnData<T> diff_dst(&cpu_engine_);
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_filter_pd,
conv_bwd_filter)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_filter_pd->diff_dst_primitive_desc());
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
}
// For backward filter, convert diff_filter back to Tensorflow layout
// Here we prepare to reorder op memory back to user memory
DCHECK(!diff_filter_mkl_shape.IsMklTensor());
auto diff_filter_format = GetOutputFormat(MKL_FMT_TAG);
auto diff_filter_md =
memory::desc(diff_filter_dims, MklDnnType<T>(), diff_filter_format);
// Convert diff_filter (output) back to TF layout if needed
// (i.e. reorder op memory back to user memory)
MklDnnData<T> diff_filter(&cpu_engine_);
bool diff_filter_reorder_required = false;
T* diff_filter_data = nullptr;
if (GetOutputFormat(tf_fmt) !=
conv_bwd_filter->GetDiffFilterMemoryFormat()) {
// Allocate diff filter tensor as Tensorflow layout
diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt),
if (IS_DIFF_FILTER_REORDER_NEEDED(diff_filter_md, diff_filter_format,
bwd_filter_pd, conv_bwd_filter)) {
// Allocate diff_filter tensor as Tensorflow layout.
diff_filter.SetUsrMem(diff_filter_dims, diff_filter_format,
diff_filter_tensor);
diff_filter_reorder_required = true;
diff_filter.PrepareReorderToUserMemIfReq(
bwd_filter_pd->diff_weights_primitive_desc());
bwd_filter_pd->PRIMITIVE_DESC_DIFF_WEIGHTS);
diff_filter_data =
static_cast<T*>(diff_filter.GetOpMem().get_data_handle());
} else {
@ -577,7 +618,7 @@ class MklConvCustomBackpropFilterOp
const_cast<T*>(diff_filter_tensor->flat<T>().data()));
}
// Execute convolution filter bwd
// Execute convolution backward filter.
if (bias_enabled) {
T* diff_bias_data =
static_cast<T*>(const_cast<T*>(diff_bias_tensor->flat<T>().data()));
@ -587,12 +628,12 @@ class MklConvCustomBackpropFilterOp
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
}
// Reorder diff_filter back to Tensorflow layout if necessary
// Reorder diff_filter back to Tensorflow layout if necessary.
if (diff_filter_reorder_required) {
diff_filter.InsertReorderToUserMem();
}
// delete primitive since it is not cached.
// Delete primitive since it is not cached.
if (do_not_cache) delete conv_bwd_filter;
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
@ -605,13 +646,12 @@ class MklConvCustomBackpropFilterOp
}
private:
const int kInputIndex_Filter = 1;
const int kInputIndex_InputSizes = 0;
const int kInputIdx = 0, kFilterIdx = 1, kDiffDstIdx = 2;
const int kDilationH = 0, kDilationW = 1;
engine cpu_engine_ = engine(engine::cpu, 0);
// Validate input shapes.
// Function asserts that input shapes are valid.
engine cpu_engine_ = engine(ENGINE_CPU, 0);
// Assert that input shapes are valid.
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
@ -656,20 +696,16 @@ class MklConvCustomBackpropFilterOp
// Output layout is Tensorflow's filter layout
// Conv2D: HWIO; Conv3D: DHWIO; Depthwise Conv: HWIGO
memory::format GetOutputFormat(const memory::format data_format) {
return is_depthwise
? memory::format::hwigo
: ((this->strides_.size() == 4) ? memory::format::hwio
: memory::format::dhwio);
MEMORY_FORMAT GetOutputFormat(const MEMORY_FORMAT data_format) {
return is_depthwise ? MEMORY_FORMAT::hwigo
: ((this->strides_.size() == 4) ? MEMORY_FORMAT::hwio
: MEMORY_FORMAT::dhwio);
}
// Allocate output tensor.
void AllocateOutputTensor(
OpKernelContext* context,
const convolution_backward_weights::primitive_desc& conv_pd,
const memory::dims& output_dims_mkl_order,
memory::format output_tf_format, Tensor** output_tensor) {
CHECK_NOTNULL(output_tensor);
void AllocateOutputTensor(OpKernelContext* context,
const memory::dims& output_dims_mkl_order,
Tensor** output_tensor) {
DCHECK(output_tensor != nullptr);
// For BackpropFilter, we convert the output tensor back in Tensorflow
// layout. Because typically, BackpropFilter is the last operator in the
@ -689,11 +725,10 @@ class MklConvCustomBackpropFilterOp
output_mkl_shape);
}
// Allocate tensor for bias grad
void AllocateBiasGradTensor(OpKernelContext* context,
const TensorShape& bias_grad_shape,
Tensor** bias_grad_tensor) {
CHECK_NOTNULL(bias_grad_tensor);
DCHECK(bias_grad_tensor);
MklDnnShape bias_grad_mkl_shape;
bias_grad_mkl_shape.SetMklTensor(false);
@ -742,6 +777,7 @@ class MklConvCustomBackpropFilterOp
TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_FILTER_KERNELS);
#undef REGISTER_MKL_FILTER_KERNELS
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS
#include <algorithm>
#include <vector>
@ -39,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@ -50,9 +52,11 @@ using mkldnn::prop_kind;
using mkldnn::stream;
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
/// utility classes enabling primitive reuse for backward conv ops.
using ConvBwdDataDesc = mkldnn::convolution_backward_data::desc;
using ConvBwdDataPd = mkldnn::convolution_backward_data::primitive_desc;
// Utility classes for enabling primitive reuse for conv bwd input.
struct MklConvBwdInputParams {
memory::dims diff_src_dims;
memory::dims filter_dims;
@ -82,20 +86,21 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
public:
explicit MklConvBwdInputPrimitive(
const MklConvBwdInputParams& convBwdInputDims)
: cpu_engine_(engine::cpu, 0) {
context_.bwd_input_stream.reset(new stream(stream::kind::eager));
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_input_stream.reset(new CPU_STREAM(cpu_engine_));
// create conv primitive
// Create conv bwd input primitive
if (context_.conv_bwd_input == nullptr) {
Setup(convBwdInputDims);
}
}
~MklConvBwdInputPrimitive() {}
// Convolution backward filter (weights)
// diff_src_data: output data buffer of diff_src
// filter_data: input data buffer of filter (weights)
// diff_dst_data: input data buffer of dst
// Convolution backward input (data) execution.
// diff_src_data: output data buffer for diff_src
// filter_data: input data buffer for filter (weights)
// diff_dst_data: input data buffer for dst
// Bias does not matter here
void Execute(const T* diff_src_data, const T* filter_data,
const T* diff_dst_data) {
@ -106,60 +111,75 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
context_.diff_dst_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_dst_data)));
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_input_primitives, context_.bwd_input_stream,
context_.bwd_input_primitives_args);
#else
context_.bwd_input_stream->submit(context_.bwd_input_primitives);
#endif // ENABLE_MKLDNN_V1
// set back data handle
// Set data handle back to DummyData.
context_.diff_src_mem->set_data_handle(DummyData);
context_.filter_mem->set_data_handle(DummyData);
context_.diff_dst_mem->set_data_handle(DummyData);
return;
}
#ifndef ENABLE_MKLDNN_V1
memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
memory::format GetDiffDstMemoryFormat() const {
return context_.diff_dst_fmt;
}
#endif // !ENABLE_MKLDNN_V1
std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
GetPrimitiveDesc() const {
std::shared_ptr<ConvBwdDataPd> GetPrimitiveDesc() const {
return context_.bwd_input_pd;
}
private:
// Primitive reuse context for Conv Bwd Input op
// Primitive reuse context for conv bwd input.
struct ConvBwdInputContext {
// expected memory format for this primitive instance
#ifndef ENABLE_MKLDNN_V1
// Expected memory format for this primitive instance.
memory::format filter_fmt;
memory::format diff_dst_fmt;
#endif
// MKLDNN memory
// MKL-DNN memory.
std::shared_ptr<mkldnn::memory> diff_src_mem;
std::shared_ptr<mkldnn::memory> filter_mem;
std::shared_ptr<mkldnn::memory> diff_dst_mem;
// convolution primitive
std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
bwd_input_pd;
// Conv backward input primitive descriptor and descriptor.
std::shared_ptr<ConvBwdDataPd> bwd_input_pd;
std::shared_ptr<ConvBwdDataDesc> bwd_input_desc;
// Primitive descriptor and descriptor for conv fwd
std::shared_ptr<ConvFwdPd> fwd_pd;
std::shared_ptr<ConvFwdDesc> fwd_desc;
// Conv bwd input primitive.
std::shared_ptr<mkldnn::primitive> conv_bwd_input;
// desc & prmitive desc
std::shared_ptr<mkldnn::convolution_backward_data::desc> bwd_input_desc;
std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;
// memory desc: forward & backward can share same memory::desc
// Memory descriptors: forward & backward share the same descriptors.
std::shared_ptr<memory::desc> diff_src_md;
std::shared_ptr<memory::desc> filter_md;
std::shared_ptr<memory::desc> diff_dst_md;
// MKL pipeline
// MKL-DNN pipeline for executing primitives.
std::shared_ptr<mkldnn::stream> bwd_input_stream;
std::vector<mkldnn::primitive> bwd_input_primitives;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> bwd_input_primitives_args;
#endif // ENABLE_MKLDNN_V1
ConvBwdInputContext()
: filter_fmt(memory::format::any),
:
#ifndef ENABLE_MKLDNN_V1
filter_fmt(memory::format::any),
diff_dst_fmt(memory::format::any),
#endif
diff_src_mem(nullptr),
filter_mem(nullptr),
diff_dst_mem(nullptr),
@ -171,49 +191,53 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
diff_src_md(nullptr),
filter_md(nullptr),
diff_dst_md(nullptr),
bwd_input_stream(nullptr) {}
bwd_input_stream(nullptr) {
}
};
void Setup(const MklConvBwdInputParams& convBwdInputDims) {
// create memory descriptors for convolution data w/ no specified format
context_.diff_src_md.reset(
new memory::desc({convBwdInputDims.diff_src_dims}, MklDnnType<T>(),
memory::format::any));
// Create memory descriptors for conv bwd input without any specified
// format so that MKL-DNN can pick an appropriate one depending on the
// input parameters.
context_.diff_src_md.reset(new memory::desc(
{convBwdInputDims.diff_src_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
context_.filter_md.reset(new memory::desc(
{convBwdInputDims.filter_dims}, MklDnnType<T>(), memory::format::any));
context_.diff_dst_md.reset(
new memory::desc({convBwdInputDims.diff_dst_dims}, MklDnnType<T>(),
memory::format::any));
{convBwdInputDims.filter_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
context_.diff_dst_md.reset(new memory::desc(
{convBwdInputDims.diff_dst_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
// create convolution primitives
context_.bwd_input_desc.reset(new convolution_backward_data::desc(
convolution_direct, *context_.diff_src_md, *context_.filter_md,
*context_.diff_dst_md, convBwdInputDims.strides,
convBwdInputDims.dilations, convBwdInputDims.padding_left,
convBwdInputDims.padding_right, convBwdInputDims.padding));
context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, convolution_direct, *context_.diff_src_md,
// Create descriptors for both conv fwd and conv bwd input.
context_.bwd_input_desc.reset(new ConvBwdDataDesc(
ALGORITHM::convolution_direct, *context_.diff_src_md,
*context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides,
convBwdInputDims.dilations, convBwdInputDims.padding_left,
convBwdInputDims.padding_right, convBwdInputDims.padding));
context_.fwd_pd.reset(new convolution_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
context_.fwd_desc.reset(new ConvFwdDesc(
prop_kind::forward, ALGORITHM::convolution_direct,
*context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md,
convBwdInputDims.strides, convBwdInputDims.dilations,
convBwdInputDims.padding_left, convBwdInputDims.padding_right,
convBwdInputDims.padding));
// create backward conv prim desc
context_.bwd_input_pd.reset(new convolution_backward_data::primitive_desc(
// Create primitive descriptors for conv fwd and conv bwd input.
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
context_.bwd_input_pd.reset(new ConvBwdDataPd(
*context_.bwd_input_desc, cpu_engine_, *context_.fwd_pd));
// create memory primitive based on dummy data
context_.diff_src_mem.reset(new memory(
context_.bwd_input_pd.get()->diff_src_primitive_desc(), DummyData));
context_.filter_mem.reset(new memory(
context_.bwd_input_pd.get()->weights_primitive_desc(), DummyData));
context_.diff_dst_mem.reset(new memory(
context_.bwd_input_pd.get()->diff_dst_primitive_desc(), DummyData));
// Create memory using dummy data.
context_.diff_src_mem.reset(new MEMORY_CONSTRUCTOR(
context_.bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine_,
DummyData));
context_.filter_mem.reset(new MEMORY_CONSTRUCTOR(
context_.bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_,
DummyData));
context_.diff_dst_mem.reset(new MEMORY_CONSTRUCTOR(
context_.bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, cpu_engine_,
DummyData));
// store the expected memory format
#ifndef ENABLE_MKLDNN_V1
// Store the expected memory format.
context_.filter_fmt =
static_cast<memory::format>(context_.bwd_input_pd.get()
->weights_primitive_desc()
@ -224,11 +248,22 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
->diff_dst_primitive_desc()
.desc()
.data.format);
#endif // !ENABLE_MKLDNN_V1
// create convolution primitive and add it to net
// Create conv bwd input primitive and add it to the net
#ifdef ENABLE_MKLDNN_V1
context_.conv_bwd_input.reset(
new convolution_backward_data(*context_.bwd_input_pd));
context_.bwd_input_primitives_args.push_back(
{{MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{ MKLDNN_ARG_DIFF_SRC,
*context_.diff_src_mem }});
#else
context_.conv_bwd_input.reset(new convolution_backward_data(
*context_.bwd_input_pd, *context_.diff_dst_mem, *context_.filter_mem,
*context_.diff_src_mem));
#endif // ENABLE_MKLDNN_V1
context_.bwd_input_primitives.push_back(*context_.conv_bwd_input);
}
@ -248,10 +283,10 @@ class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
const MklConvBwdInputParams& convBwdInputDims, bool do_not_cache) {
MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
if (do_not_cache) { /* Always allocate primitive */
if (do_not_cache) { // Always allocate primitive.
conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
} else {
// look into the pool for reusable primitive
// look into the pool for reusable primitive.
conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
convBwdInputDims));
@ -308,14 +343,7 @@ class MklConvCustomBackpropInputOp
void Compute(OpKernelContext* context) {
try {
MklDnnData<T> filter(&cpu_engine);
MklDnnData<T> diff_dst(&cpu_engine);
// This flag indicate Conv2D or Conv3D
bool is_conv2d = (this->strides_.size() == 4);
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
// Input tensors.
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
@ -350,9 +378,9 @@ class MklConvCustomBackpropInputOp
AllocateOutputSetMklShape(context, kOutputIdx, &diff_src_tensor,
diff_src_tf_shape, diff_src_mkl_shape,
eager_mode);
CHECK_NOTNULL(diff_src_tensor);
DCHECK(diff_src_tensor != nullptr);
// if output tensor has more than 0 elements, we need to 0 them out.
// If output tensor has more than 0 elements, we need to 0 them out.
auto diff_src_data = diff_src_tensor->flat<T>().data();
for (size_t i = 0; i < diff_src_tf_shape.num_elements(); ++i) {
diff_src_data[i] = static_cast<T>(0);
@ -360,28 +388,36 @@ class MklConvCustomBackpropInputOp
return;
}
// By default, all dims are in MKL order. Only dims in TF order
// are those with postfix tf_order.
// By default, all dims are in MKL order except those that are suffixed
// with `tf_order`.
memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
memory::dims padding_left, padding_right, dilations, strides;
memory::dims fwd_output_dims, fwd_output_dims_tf_order;
// Get forward convolution parameters.
MklDnnConvUtil conv_utl(context, this->strides_, this->padding_,
this->data_format_, this->dilations_);
conv_utl.GetConvFwdSizesInMklOrder(
// Get conv fwd parameters.
MklDnnConvUtil conv_util(context, this->strides_, this->padding_,
this->data_format_, this->dilations_);
conv_util.GetConvFwdSizesInMklOrder(
src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims,
&strides, &dilations, &fwd_output_dims_tf_order, &fwd_output_dims,
&padding_left, &padding_right, false, is_depthwise);
if (!context->status().ok()) return;
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
bool is_conv2d = (this->strides_.size() == 4);
// Create conv fwd descriptor since conv bwd input API needs it.
// For that, we first need to create input, filter and output memory
// descriptors.
auto tf_fmt = is_conv2d
? TFDataFormatToMklDnnDataFormat(this->data_format_)
: TFDataFormatToMklDnn3DDataFormat(this->data_format_);
#ifdef ENABLE_MKLDNN_V1
auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
errors::InvalidArgument("Invalid data format"));
#endif // ENABLE_MKLDNN_V1
// If filter is in MKL layout, then simply grab filter layout;
// otherwise, construct filter in TF layout.
// For TF layout, filter is in HWIO format.
@ -389,42 +425,47 @@ class MklConvCustomBackpropInputOp
filter_mkl_shape.IsMklTensor()
? filter_mkl_shape.GetMklLayout()
: memory::desc(fwd_filter_dims, MklDnnType<T>(),
is_depthwise
? memory::hwigo
: (is_conv2d ? memory::format::hwio
: memory::format::dhwio));
is_depthwise ? MEMORY_FORMAT::hwigo
: (is_conv2d ? MEMORY_FORMAT::hwio
: MEMORY_FORMAT::dhwio));
conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
conv_util.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
if (!context->status().ok()) return;
auto diff_dst_md =
diff_dst_mkl_shape.IsMklTensor()
? diff_dst_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims, MklDnnType<T>(), tf_fmt);
for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
: memory::desc(diff_dst_dims, MklDnnType<T>(), MKL_FMT_TAG);
// The default dilation factor for each dimension is 1 in TF and
// 0 in MKL-DNN.
for (int i = 0; i < dilations.size(); ++i) --dilations[i];
MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
MklConvBwdInputParams convBwdInputDims(
fwd_src_dims, fwd_filter_dims, diff_dst_dims, strides, dilations,
padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
// We don't cache those primitves if the env variable
// We don't cache those primitives if the environment variable
// TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitve descriptor
// includes potentialy large buffers. MKL DNN allocates buffers
// includes potentialy large buffers. MKL-DNN allocates buffers
// in the following cases
// 1. Legacy CPU without AVX512/AVX2, or
// 2. 1x1 convolution with stride != 1
bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
(MklPrimitiveFactory<T>::IsLegacyPlatform() ||
IsConv1x1StrideNot1(fwd_filter_dims, strides));
conv_bwd_input = MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims,
do_not_cache);
auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
// allocate output tensor
auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc();
MklConvBwdInputPrimitive<T>* conv_bwd_input =
MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims,
do_not_cache);
auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
auto diff_src_pd = bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_SRC;
auto bwd_diff_src_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
auto bwd_diff_src_format = GetOutputFormat(tf_fmt);
// Allocate output tensor.
MklDnnShape diff_src_mkl_shape;
diff_src_mkl_shape.SetMklTensor(true);
diff_src_mkl_shape.SetMklLayout(&diff_src_pd);
@ -443,12 +484,14 @@ class MklConvCustomBackpropInputOp
T* diff_src_data =
static_cast<T*>(const_cast<T*>(diff_src_tensor->flat<T>().data()));
// check if filter and diff_dst need reorder
// Check if filter and diff_dst need to be reordered.
T* filter_data = nullptr;
if (fwd_filter_md.data.format !=
conv_bwd_input->GetFilterMemoryFormat()) {
MklDnnData<T> filter(&cpu_engine_);
if (IS_FILTER_REORDER_NEEDED(fwd_filter_md, bwd_input_pd,
conv_bwd_input)) {
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc());
filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
} else {
filter_data =
@ -456,16 +499,19 @@ class MklConvCustomBackpropInputOp
}
T* diff_dst_data = nullptr;
if (diff_dst_md.data.format != conv_bwd_input->GetDiffDstMemoryFormat()) {
MklDnnData<T> diff_dst(&cpu_engine_);
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_input_pd,
conv_bwd_input)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc());
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
}
// execute convolution input bwd
// Execute conv bwd input primitive.
if (!eager_mode) {
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
} else {
@ -475,18 +521,20 @@ class MklConvCustomBackpropInputOp
static_cast<T*>(const_cast<T*>(tmp_tensor.flat<T>().data()));
conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data);
auto output_tf_md = diff_src_mkl_shape.GetTfLayout();
auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
mkldnn::reorder::primitive_desc reorder_pd =
mkldnn::reorder::primitive_desc(diff_src_pd, output_tf_pd);
std::vector<mkldnn::primitive> net;
memory* tmp_data_mem = new memory(diff_src_pd, tmp_data);
memory* dst_data_mem = new memory(output_tf_pd, diff_src_data);
net.push_back(
mkldnn::reorder(reorder_pd, *tmp_data_mem, *dst_data_mem));
stream(stream::kind::eager).submit(net).wait();
#ifndef ENABLE_MKLDNN_V1
auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_);
#endif
ReorderPd reorder_pd =
REORDER_PD_CONSTRUCTOR(diff_src_pd, OUTPUT_TF_MD, cpu_engine_);
memory* tmp_data_mem =
new MEMORY_CONSTRUCTOR(diff_src_pd, cpu_engine_, tmp_data);
memory* dst_data_mem =
new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, diff_src_data);
CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
cpu_engine_);
}
// delete primitive since it is not cached.
// Delete primitive since it is not cached.
if (do_not_cache) {
delete conv_bwd_input;
}
@ -501,12 +549,12 @@ class MklConvCustomBackpropInputOp
}
private:
const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0;
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const int kDilationH = 0, kDilationW = 1;
engine cpu_engine = engine(engine::cpu, 0);
// Validate input shapes.
// Function asserts that input shapes are valid.
engine cpu_engine_ = engine(ENGINE_CPU, 0);
// Assert that input shapes are valid.
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
@ -532,7 +580,7 @@ class MklConvCustomBackpropInputOp
// Get TensorFlow shape of filter tensor.
TensorShape MakeFilterTfShape(OpKernelContext* context,
const Tensor& filter_tensor) {
return GetTfShape(context, kInputIndex_Filter, eager_mode);
return GetTfShape(context, kFilterIdx, eager_mode);
}
// Get the Tensorflow shape of Output (diff_src),
@ -543,30 +591,29 @@ class MklConvCustomBackpropInputOp
return input_shape;
}
// Get the Tensorflow shape of Output (diff_src),
// which is same as shape of Conv 'input'.
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
return fwd_input_dims;
}
// Output layout is Tensorflow's layout in data format order.
memory::format GetOutputFormat(const memory::format data_format) {
MKL_TENSOR_FORMAT GetOutputFormat(const MKL_TENSOR_FORMAT data_format) {
return data_format;
}
// Allocate output tensor.
void AllocateOutputTensor(
OpKernelContext* context,
const convolution_backward_data::primitive_desc& conv_pd,
const memory::dims& output_dims_mkl_order,
memory::format output_tf_format, Tensor** output_tensor) {
CHECK_NOTNULL(output_tensor);
// TODO(bhavanis): Move this function to mkl_util.h since it is common to
// both the forward and backward implementations
void AllocateOutputTensor(OpKernelContext* context,
const ConvBwdDataPd& conv_pd,
const memory::dims& output_dims_mkl_order,
MKL_TENSOR_FORMAT output_tf_format,
Tensor** output_tensor) {
DCHECK(output_tensor != nullptr);
// Output primitive descriptor for backward data is diff_src.
auto dst_pd = conv_pd.diff_src_primitive_desc();
auto dst_pd = conv_pd.PRIMITIVE_DESC_DIFF_SRC;
// Allocate shape of Mkl tensor.
// Allocate shape of MKL tensor.
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(true);
output_mkl_shape.SetMklLayout(&dst_pd);
@ -608,8 +655,10 @@ class MklConvCustomBackpropInputOp
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropInputOp<CPUDevice, T, true, false>);
TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_CPU_KERNELS);
#undef REGISTER_MKL_CPU_KERNELS
} // namespace tensorflow

View File

@ -55,6 +55,7 @@ namespace tensorflow {
#define MKLDNN_SIZE_DTYPE int
#endif // ENABLE_MKLDNN_V1
using ConvFwdDesc = mkldnn::convolution_forward::desc;
using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
class MklDnnConvUtil {