Merge pull request #36630 from Intel-tensorflow:dnn10_conv_bwd
PiperOrigin-RevId: 295228960
Change-Id: I2bc99b61cd0a8dc9b9771a0e65e393fd6176d5d9
commit 81323b7924
@@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -46,8 +47,12 @@ using mkldnn::prop_kind;
using mkldnn::stream;

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

using ConvBwdFilterDesc = mkldnn::convolution_backward_weights::desc;
using ConvBwdFilterPd = mkldnn::convolution_backward_weights::primitive_desc;

struct MklConvBwdFilterParams {
memory::dims src_dims;
memory::dims diff_filter_dims;
@@ -80,9 +85,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
public:
explicit MklConvBwdFilterPrimitive(
const MklConvBwdFilterParams& convBwdFilterDims)
: cpu_engine_(engine::cpu, 0) {
context_.bwd_filter_stream.reset(new stream(stream::kind::eager));
// create conv primitive
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_filter_stream.reset(new CPU_STREAM(cpu_engine_));

// Create convolution backward filter primitive.
if (context_.conv_bwd_filter == nullptr) {
Setup(convBwdFilterDims);
}
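Note: the ENGINE_CPU and CPU_STREAM spellings above come from tensorflow/core/util/mkl_types.h, which papers over API differences between MKL-DNN v0.x and v1.x. The following is only a sketch of how such compatibility macros can be laid out; the exact definitions are an assumption here and may differ from the real header.

// Hypothetical compatibility macros in the spirit of mkl_types.h.
// Under MKL-DNN v1.x the engine kind is scoped and streams are bound to
// an engine; under v0.x the engine kind is unscoped and streams are eager.
#ifdef ENABLE_MKLDNN_V1
#define ENGINE_CPU engine::kind::cpu
#define CPU_STREAM(engine) stream(engine)
#else
#define ENGINE_CPU engine::cpu
#define CPU_STREAM(engine) stream(stream::kind::eager)
#endif  // ENABLE_MKLDNN_V1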
@@ -90,106 +96,111 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {

~MklConvBwdFilterPrimitive() {}

// Convolution backward weights with bias
// src_data: input data buffer of src
// diff_filter_data: output data buffer of diff_filter
// diff_bias_data: output data buffer of diff_bias
// diff_dst_data: input data buffer of diff_dst
// Convolution backward weights execution with bias
// src_data: input data buffer for src
// diff_filter_data: output data buffer for diff_filter
// diff_bias_data: output data buffer for diff_bias
// diff_dst_data: input data buffer for diff_dst
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_bias_data, const T* diff_dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_filter_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_filter_data)));
context_.diff_bias_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_bias_data)));
if (diff_bias_data != nullptr) {
context_.diff_bias_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_bias_data)));
}
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));

#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_filter_primitives,
context_.bwd_filter_stream,
context_.bwd_filter_primitives_args);
#else
context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);
#endif

context_.src_mem->set_data_handle(DummyData);
context_.diff_filter_mem->set_data_handle(DummyData);
context_.diff_bias_mem->set_data_handle(DummyData);
if (diff_bias_data != nullptr) {
context_.diff_bias_mem->set_data_handle(DummyData);
}
context_.diff_dst_mem->set_data_handle(DummyData);
return;
}
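The Execute method above shows the lifecycle that makes a cached primitive reusable: point the primitive's memory objects at the caller's buffers, run, then detach by resetting the handles to the DummyData sentinel so no stale pointer survives the call. A minimal stand-alone sketch of the same lifecycle, using a placeholder Memory type rather than the real mkldnn::memory:

#include <cassert>

// Placeholder for mkldnn::memory: holds a non-owning data pointer.
struct Memory {
  void* handle = nullptr;
  void set_data_handle(void* p) { handle = p; }
};

static char dummy;                      // stands in for TF's DummyData
static void* const kDummyData = &dummy;

struct CachedPrimitive {
  Memory src, dst;
  void Run() { assert(src.handle != kDummyData && dst.handle != kDummyData); }
  void Execute(void* src_buf, void* dst_buf) {
    src.set_data_handle(src_buf);       // attach caller buffers
    dst.set_data_handle(dst_buf);
    Run();                              // compute with current handles
    src.set_data_handle(kDummyData);    // detach so stale pointers can't leak
    dst.set_data_handle(kDummyData);
  }
};

int main() {
  int in = 1, out = 0;
  CachedPrimitive p;
  p.Execute(&in, &out);
}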

// Convolution backward weights without bias
// Convolution backward weights without bias.
// src_data: input data buffer of src
// diff_filter_data: output data buffer of diff_filter
// diff_dst_data: input data buffer of diff_dst
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_dst_data) {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_filter_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_filter_data)));
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));

context_.bwd_filter_stream->submit(context_.bwd_filter_primitives);

context_.src_mem->set_data_handle(DummyData);
context_.diff_filter_mem->set_data_handle(DummyData);
context_.diff_dst_mem->set_data_handle(DummyData);
return;
Execute(src_data, diff_filter_data, nullptr, diff_dst_data);
}

#ifndef ENABLE_MKLDNN_V1
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }

memory::format GetDiffDstMemoryFormat() const {
return context_.diff_dst_fmt;
}

memory::format GetDiffFilterMemoryFormat() const {
return context_.diff_filter_fmt;
}
#endif

// convolution primitive
std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
GetPrimitiveDesc() const {
std::shared_ptr<ConvBwdFilterPd> GetPrimitiveDesc() const {
return context_.bwd_filter_pd;
}

private:
// Primitive reuse context for Conv2D bwd filter op
// Primitive reuse context for Conv2D backward filter op.
struct ConvBwdFilterContext {
// expected memory format for this primitive instance
#ifndef ENABLE_MKLDNN_V1
// Expected memory format for this primitive instance
memory::format src_fmt;
memory::format diff_dst_fmt;
memory::format diff_filter_fmt;
#endif  // !ENABLE_MKLDNN_V1

// convolution bwd input primitive
std::shared_ptr<mkldnn::convolution_backward_weights::primitive_desc>
bwd_filter_pd;
std::shared_ptr<mkldnn::primitive> conv_bwd_filter;

// MKLDNN memory
// MKL-DNN memory for inputs and outputs.
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> diff_filter_mem;
std::shared_ptr<mkldnn::memory> diff_bias_mem;
std::shared_ptr<mkldnn::memory> diff_dst_mem;

// desc & prmitive desc
std::shared_ptr<mkldnn::convolution_backward_weights::desc> bwd_filter_desc;
std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;
// Primitive descriptor and descriptor for convolution backward filter.
std::shared_ptr<ConvBwdFilterPd> bwd_filter_pd;
std::shared_ptr<ConvBwdFilterDesc> bwd_filter_desc;

// memory desc: forward & backward can share same memory desc
// Primitive descriptor and descriptor for convolution forward.
std::shared_ptr<ConvFwdPd> fwd_pd;
std::shared_ptr<ConvFwdDesc> fwd_desc;

// Convolution backward filter primitive.
std::shared_ptr<mkldnn::primitive> conv_bwd_filter;

// Memory descriptors: forward & backward share the same memory descriptors
std::shared_ptr<mkldnn::memory::desc> src_md;
std::shared_ptr<mkldnn::memory::desc> diff_filter_md;
std::shared_ptr<mkldnn::memory::desc> diff_bias_md;
std::shared_ptr<mkldnn::memory::desc> diff_dst_md;

// MKL pipeline
// MKL-DNN pipeline for executing primitives.
std::shared_ptr<mkldnn::stream> bwd_filter_stream;
std::vector<mkldnn::primitive> bwd_filter_primitives;

#ifdef ENABLE_MKLDNN_V1
std::vector<MemoryArgsMap> bwd_filter_primitives_args;
#endif

ConvBwdFilterContext()
: src_fmt(memory::format::any),
:
#ifndef ENABLE_MKLDNN_V1
src_fmt(memory::format::any),
diff_dst_fmt(memory::format::any),
diff_filter_fmt(memory::format::any),
#endif
src_mem(nullptr),
diff_filter_mem(nullptr),
diff_bias_mem(nullptr),
@@ -201,84 +212,102 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
diff_filter_md(nullptr),
diff_bias_md(nullptr),
diff_dst_md(nullptr),
bwd_filter_stream(nullptr) {}
bwd_filter_stream(nullptr) {
}
};

// Setup Conv2d backward filter (weights) primitives.
void Setup(const MklConvBwdFilterParams& convBwdFilterDims) {
// create memory descriptors for convolution data w/ no specified format
// Create memory descriptors for convolution backward filter without any
// specific format so that MKL-DNN can pick an appropriate one depending
// on the input parameters.
context_.src_md.reset(new memory::desc(
{convBwdFilterDims.src_dims}, MklDnnType<T>(), memory::format::any));
{convBwdFilterDims.src_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));

context_.diff_dst_md.reset(
new memory::desc({convBwdFilterDims.diff_dst_dims}, MklDnnType<T>(),
memory::format::any));
MEMORY_FORMAT::any));

context_.diff_filter_md.reset(
new memory::desc({convBwdFilterDims.diff_filter_dims}, MklDnnType<T>(),
memory::format::any));
MEMORY_FORMAT::any));

if (!convBwdFilterDims.diff_bias_dims.empty())
context_.diff_bias_md.reset(
new memory::desc({convBwdFilterDims.diff_bias_dims}, MklDnnType<T>(),
memory::format::x));
MEMORY_FORMAT::x));

// create a convolution
if (!convBwdFilterDims.diff_bias_dims.empty()) {
context_.bwd_filter_desc.reset(new convolution_backward_weights::desc(
convolution_direct, *context_.src_md, *context_.diff_filter_md,
*context_.diff_bias_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
convBwdFilterDims.padding));
} else {
context_.bwd_filter_desc.reset(new convolution_backward_weights::desc(
convolution_direct, *context_.src_md, *context_.diff_filter_md,
*context_.diff_dst_md, convBwdFilterDims.strides,
convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
convBwdFilterDims.padding_right, convBwdFilterDims.padding));
}

// create fwd primitive_desc
context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, convolution_direct, *context_.src_md,
// Create descriptor and primitive descriptor for convolution forward.
context_.fwd_desc.reset(new ConvFwdDesc(
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
*context_.diff_filter_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
convBwdFilterDims.padding));
context_.fwd_pd.reset(new convolution_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));

// create backward conv primitive_desc
context_.bwd_filter_pd.reset(
new convolution_backward_weights::primitive_desc(
*context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));
// Create descriptor and primitive descriptor for convolution bwd filter.
if (!convBwdFilterDims.diff_bias_dims.empty()) {
context_.bwd_filter_desc.reset(new ConvBwdFilterDesc(
ALGORITHM::convolution_direct, *context_.src_md,
*context_.diff_filter_md, *context_.diff_bias_md,
*context_.diff_dst_md, convBwdFilterDims.strides,
convBwdFilterDims.dilations, convBwdFilterDims.padding_left,
convBwdFilterDims.padding_right, convBwdFilterDims.padding));
} else {
context_.bwd_filter_desc.reset(new ConvBwdFilterDesc(
ALGORITHM::convolution_direct, *context_.src_md,
*context_.diff_filter_md, *context_.diff_dst_md,
convBwdFilterDims.strides, convBwdFilterDims.dilations,
convBwdFilterDims.padding_left, convBwdFilterDims.padding_right,
convBwdFilterDims.padding));
}
context_.bwd_filter_pd.reset(new ConvBwdFilterPd(
*context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd));
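One non-obvious step in Setup() above is that a convolution forward descriptor and primitive descriptor are built even though this op only computes gradients: MKL-DNN's backward-weights primitive_desc takes the forward primitive_desc as a hint, so both passes resolve the "any" memory formats to mutually consistent layouts. A stand-alone sketch of that dependency, with stand-in types rather than mkldnn's:

// Stand-in types only; real code uses mkldnn::convolution_forward /
// convolution_backward_weights primitive descriptors.
struct ConvParams {};  // dims, strides, padding, ...

struct FwdPd {
  explicit FwdPd(const ConvParams&) { /* resolve "any" to blocked layouts */ }
};

struct BwdFilterPd {
  BwdFilterPd(const ConvParams&, const FwdPd&) {
    // The hint lets backward resolve its "any" formats to layouts that
    // stay consistent with the forward pass.
  }
};

int main() {
  ConvParams params;
  FwdPd fwd_pd(params);                // built only to serve as a hint
  BwdFilterPd bwd_pd(params, fwd_pd);  // what context_.bwd_filter_pd models
}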

// store the expected memory format
auto bwd_filter_pd = context_.bwd_filter_pd.get();

#ifndef ENABLE_MKLDNN_V1
// Store the expected memory format.
context_.src_fmt = static_cast<mkldnn::memory::format>(
bwd_filter_pd->src_primitive_desc().desc().data.format);
context_.diff_filter_fmt = static_cast<mkldnn::memory::format>(
bwd_filter_pd->diff_weights_primitive_desc().desc().data.format);
context_.diff_dst_fmt = static_cast<mkldnn::memory::format>(
bwd_filter_pd->diff_dst_primitive_desc().desc().data.format);
#endif  // !ENABLE_MKLDNN_V1

// create memory primitive based on dummy data
context_.src_mem.reset(
new memory(bwd_filter_pd->src_primitive_desc(), DummyData));
context_.diff_filter_mem.reset(
new memory(bwd_filter_pd->diff_weights_primitive_desc(), DummyData));
context_.diff_dst_mem.reset(
new memory(bwd_filter_pd->diff_dst_primitive_desc(), DummyData));
// Create memory using dummy data.
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_filter_pd->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData));
context_.diff_filter_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_filter_pd->PRIMITIVE_DESC_DIFF_WEIGHTS, cpu_engine_, DummyData));
context_.diff_dst_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_, DummyData));

// create convolution primitive and add it to net
// Create convolution backward filter primitive and add it to the net.
if (!convBwdFilterDims.diff_bias_dims.empty()) {
context_.diff_bias_mem.reset(
new memory({{{convBwdFilterDims.diff_bias_dims},
MklDnnType<T>(),
memory::format::x},
cpu_engine_},
DummyData));
context_.diff_bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
convBwdFilterDims.diff_bias_dims, T, MEMORY_FORMAT::x, cpu_engine_,
DummyData));
#ifdef ENABLE_MKLDNN_V1
context_.conv_bwd_filter.reset(
new convolution_backward_weights(*context_.bwd_filter_pd));
context_.bwd_filter_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_DIFF_WEIGHTS, *context_.diff_filter_mem},
{MKLDNN_ARG_DIFF_BIAS, *context_.diff_bias_mem},
{ MKLDNN_ARG_DIFF_DST,
*context_.diff_dst_mem }});
} else {
context_.conv_bwd_filter.reset(
new convolution_backward_weights(*context_.bwd_filter_pd));
context_.bwd_filter_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_DIFF_WEIGHTS, *context_.diff_filter_mem},
{ MKLDNN_ARG_DIFF_DST,
*context_.diff_dst_mem }});
}
#else
context_.conv_bwd_filter.reset(new convolution_backward_weights(
*context_.bwd_filter_pd, *context_.src_mem, *context_.diff_dst_mem,
*context_.diff_filter_mem, *context_.diff_bias_mem));
@@ -287,7 +316,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
*context_.bwd_filter_pd, *context_.src_mem, *context_.diff_dst_mem,
*context_.diff_filter_mem));
}

#endif  // ENABLE_MKLDNN_V1
context_.bwd_filter_primitives.push_back(*context_.conv_bwd_filter);
}

@@ -305,7 +334,7 @@ class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
if (do_not_cache) { /* Create new primitive always */
conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
} else {
// look into the pool for reusable primitive
// Look into the pool for reusable primitive.
conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*>(
MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
convBwdFilterDims));
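The factory hunk above relies on a memoization idiom used throughout the MKL kernels: a process-wide singleton maps a key derived from all shape/stride/padding parameters to a previously built primitive, so the expensive primitive-descriptor creation runs once per distinct configuration. A minimal illustrative sketch (names here are ours, not TensorFlow's):

#include <map>
#include <memory>
#include <string>

struct Primitive {
  virtual ~Primitive() = default;
};

class PrimitiveCache {
 public:
  static PrimitiveCache& GetInstance() {
    static PrimitiveCache instance;  // one pool per process
    return instance;
  }
  Primitive* Find(const std::string& key) {
    auto it = pool_.find(key);
    return it == pool_.end() ? nullptr : it->second.get();
  }
  void Insert(const std::string& key, std::unique_ptr<Primitive> prim) {
    pool_[key] = std::move(prim);
  }

 private:
  // Key should serialize every parameter that affects the compiled
  // primitive, mirroring what a FactoryKeyCreator-style helper does.
  std::map<std::string, std::unique_ptr<Primitive>> pool_;
};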
@@ -369,23 +398,15 @@ class MklConvCustomBackpropFilterOp

void Compute(OpKernelContext* context) {
try {
MklDnnData<T> src(&cpu_engine_);
MklDnnData<T> diff_dst(&cpu_engine_);
MklDnnData<T> diff_filter(&cpu_engine_);  // output

// This flag indicates Conv2D or Conv3D
bool is_conv2d = (this->strides_.size() == 4);

// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
// Input tensors.
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIdx);

MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape;
GetMklShape(context, kInputIdx, &src_mkl_shape, eager_mode);
GetMklShape(context, kFilterIdx, &filter_mkl_shape, eager_mode);
GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape, eager_mode);
GetMklShape(context, kDiffDstIdx, &diff_dst_mkl_shape, eager_mode);
// Allow operator-specific sanity checking of shapes.
ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape);

@@ -397,7 +418,7 @@ class MklConvCustomBackpropFilterOp
TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor);
TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
TensorShape diff_dst_tf_shape =
GetTfShape(context, kOutbpropIdx, eager_mode);
GetTfShape(context, kDiffDstIdx, eager_mode);

// Corner cases: output with 0 elements and 0 batch size.
Tensor* diff_filter_tensor = nullptr;
@@ -412,9 +433,9 @@ class MklConvCustomBackpropFilterOp
AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor,
diff_filter_tf_shape, diff_filter_mkl_shape,
eager_mode);
CHECK_NOTNULL(diff_filter_tensor);
DCHECK(diff_filter_tensor != nullptr);

// if output tensor has more than 0 elements, we need to 0 them out.
// If output tensor has more than 0 elements, we need to 0 them out.
auto diff_filter_data = diff_filter_tensor->flat<T>().data();
for (size_t i = 0; i < diff_filter_tf_shape.num_elements(); ++i) {
diff_filter_data[i] = static_cast<T>(0);
@@ -422,38 +443,44 @@ class MklConvCustomBackpropFilterOp
return;
}

// By default, all dims are in MKL order. Only dims in TF order
// are those with prefix tf_order.
// By default, all dims are in MKL order except those that are suffixed
// with `tf_order`
memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
memory::dims padding_left, padding_right, dilations, strides,
fwd_dst_dims;
memory::dims fwd_dst_dims_tf_order;
memory::dims padding_left, padding_right, dilations, strides;
memory::dims fwd_dst_dims, fwd_dst_dims_tf_order;

// Get forward convolution parameters.
MklDnnConvUtil conv_utl(context, this->strides_, this->padding_,
this->data_format_, this->dilations_);
conv_utl.GetConvFwdSizesInMklOrder(
MklDnnConvUtil conv_util(context, this->strides_, this->padding_,
this->data_format_, this->dilations_);
conv_util.GetConvFwdSizesInMklOrder(
src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims,
&strides, &dilations, &fwd_dst_dims_tf_order, &fwd_dst_dims,
&padding_left, &padding_right, false, is_depthwise);
if (!context->status().ok()) return;

bool is_conv2d = (this->strides_.size() == 4);

auto tf_fmt = is_conv2d
? TFDataFormatToMklDnnDataFormat(this->data_format_)
: TFDataFormatToMklDnn3DDataFormat(this->data_format_);
#ifdef ENABLE_MKLDNN_V1
auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
errors::InvalidArgument("Invalid data format"));
#endif

auto fwd_src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(fwd_src_dims, MklDnnType<T>(), tf_fmt);
: memory::desc(fwd_src_dims, MklDnnType<T>(), MKL_FMT_TAG);

conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
conv_util.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
if (!context->status().ok()) return;

auto diff_dst_md =
diff_dst_mkl_shape.IsMklTensor()
? diff_dst_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims, MklDnnType<T>(), tf_fmt);
: memory::desc(diff_dst_dims, MklDnnType<T>(), MKL_FMT_TAG);

memory::dims diff_bias_dims = {};
int64 depth = 0;
@@ -464,26 +491,28 @@ class MklConvCustomBackpropFilterOp
: obp_tf_shape.dim_size(is_conv2d ? 3 : 4);
diff_bias_dims = {static_cast<int>(depth)};
}
for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;

MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
// The default dilation factor for each dimension is 1 in TF and
// 0 in MKL-DNN.
for (int i = 0; i < dilations.size(); ++i) --dilations[i];

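The off-by-one decrement above encodes a real convention difference: TensorFlow uses dilation >= 1 where 1 means dense, while MKL-DNN counts only the extra gaps, so dense is 0. Either way the effective kernel extent is the same; a tiny self-contained check:

#include <cstdio>

// Effective kernel extent for kernel size k and TF dilation d:
//   k_eff = (k - 1) * d_tf + 1 == (k - 1) * (d_mkl + 1) + 1
int main() {
  int k = 3;
  for (int d_tf = 1; d_tf <= 3; ++d_tf) {
    int d_mkl = d_tf - 1;  // the `--dilations[i]` in the hunk above
    int k_eff = (k - 1) * d_tf + 1;
    std::printf("d_tf=%d d_mkl=%d k_eff=%d\n", d_tf, d_mkl, k_eff);
  }
}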
MklConvBwdFilterParams convBwdFilterDims(
fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides,
dilations, padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));

// MKL DNN allocates large buffers when a conv gradient filter primtive is
// MKL-DNN allocates large buffers when a conv gradient filter primtive is
// created. So we don't cache conv backward primitives when the env
// variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true.
bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
convBwdFilterDims, do_not_cache);
auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();

// allocate output tensors: diff_fitler and diff_bias (w bias)
auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
MklConvBwdFilterPrimitive<T>* conv_bwd_filter =
MklConvBwdFilterPrimitiveFactory<T>::Get(convBwdFilterDims,
do_not_cache);

// Allocate output tensors: diff_filter and diff_bias (w bias).
auto diff_filter_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);

// diff_filter
MklDnnShape diff_filter_mkl_shape;
diff_filter_mkl_shape.SetMklTensor(false);

@@ -491,15 +520,15 @@ class MklConvCustomBackpropFilterOp
if (!is_depthwise) {
// Conv2D: output_dims_mkl_order is in OIHW format.
TensorShape diff_filter_tf_shape(
{bwd_output_dims[MklDnnDims::Dim_H],
bwd_output_dims[MklDnnDims::Dim_W],
bwd_output_dims[MklDnnDims::Dim_I],
bwd_output_dims[MklDnnDims::Dim_O]});
{diff_filter_dims[MklDnnDims::Dim_H],
diff_filter_dims[MklDnnDims::Dim_W],
diff_filter_dims[MklDnnDims::Dim_I],
diff_filter_dims[MklDnnDims::Dim_O]});
AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
diff_filter_tf_shape, diff_filter_mkl_shape,
eager_mode);
} else {
// Depthwise Conv2d: bwd_output_dims is GOIHW format
// Depthwise Conv2d: diff_filter_dims is GOIHW format.
// | TensorFlow | MKLDNN
// ----------------------------------------------------------------
// filter_out_depth | depth_multiplier | depth_multiplier *
@@ -511,10 +540,11 @@ class MklConvCustomBackpropFilterOp
// And the GOIHW is mkldnn format, here we try to extract the TF
// format, TF format is HWIO, as G = original I, so here is HWGO.
TensorShape diff_filter_tf_shape(
{bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_H],
bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_W],
bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_G],
bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_O]});
{diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_H],
diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_W],
diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_G],
diff_filter_dims
[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_O]});
AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
diff_filter_tf_shape,
diff_filter_mkl_shape);
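For the depthwise case above, MKL-DNN returns the gradient in grouped GOIHW order, and because G equals the original input-channel count the exported TF shape becomes HWGO (H, W, G, O) rather than plain HWIO. A small sketch of the dimension shuffle, with made-up sizes:

#include <array>
#include <cstdio>

int main() {
  // [G, O, I, H, W] with O = depth_multiplier and I = 1 for depthwise.
  std::array<int, 5> goihw = {8, 2, 1, 3, 3};
  // Same permutation as the TensorShape construction in the hunk above.
  std::array<int, 4> hwgo = {goihw[3], goihw[4], goihw[0], goihw[1]};
  std::printf("HWGO = %d x %d x %d x %d\n", hwgo[0], hwgo[1], hwgo[2],
              hwgo[3]);
}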
@@ -522,11 +552,11 @@ class MklConvCustomBackpropFilterOp
} else {
// Conv3D: output_dims_mkl_order is in OIDHW format.
TensorShape diff_filter_tf_shape(
{bwd_output_dims[MklDnnDims3D::Dim3d_D],
bwd_output_dims[MklDnnDims3D::Dim3d_H],
bwd_output_dims[MklDnnDims3D::Dim3d_W],
bwd_output_dims[MklDnnDims3D::Dim3d_I],
bwd_output_dims[MklDnnDims3D::Dim3d_O]});
{diff_filter_dims[MklDnnDims3D::Dim3d_D],
diff_filter_dims[MklDnnDims3D::Dim3d_H],
diff_filter_dims[MklDnnDims3D::Dim3d_W],
diff_filter_dims[MklDnnDims3D::Dim3d_I],
diff_filter_dims[MklDnnDims3D::Dim3d_O]});
AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
diff_filter_tf_shape, diff_filter_mkl_shape);
}
@@ -537,39 +567,50 @@ class MklConvCustomBackpropFilterOp
AllocateBiasGradTensor(context, diff_bias_shape, &diff_bias_tensor);
}

// check if src and diff_dst need reorder
// Check if src and diff_dst need to be reordered.
T* src_data = nullptr;
if (fwd_src_md.data.format != conv_bwd_filter->GetSrcMemoryFormat()) {
MklDnnData<T> src(&cpu_engine_);
auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
if (IS_SRC_REORDER_NEEDED(fwd_src_md, bwd_filter_pd, conv_bwd_filter)) {
src.SetUsrMem(fwd_src_md, &src_tensor);
src.CheckReorderToOpMem(bwd_filter_pd->src_primitive_desc());
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_filter_pd->PRIMITIVE_DESC_SRC, cpu_engine_));
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
}

T* diff_dst_data = nullptr;
if (diff_dst_md.data.format !=
conv_bwd_filter->GetDiffDstMemoryFormat()) {
MklDnnData<T> diff_dst(&cpu_engine_);
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_filter_pd,
conv_bwd_filter)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_filter_pd->diff_dst_primitive_desc());
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
}

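The IS_*_REORDER_NEEDED / CheckReorderToOpMem pairs above all follow one idea: the primitive wants a specific blocked layout, so copy into a scratch buffer only when the user tensor's layout differs, otherwise use the user buffer directly. A stand-alone sketch of that decision, with layouts reduced to an enum instead of MKL-DNN memory descriptors:

#include <cstddef>
#include <cstring>
#include <vector>

enum class Layout { kNChw, kBlocked16c };

template <typename T>
const T* MaybeReorder(const T* user_data, std::size_t n, Layout user_fmt,
                      Layout op_fmt, std::vector<T>* scratch) {
  if (user_fmt == op_fmt) return user_data;  // zero-copy fast path
  scratch->resize(n);
  // A real reorder permutes elements; a plain copy stands in for it here.
  std::memcpy(scratch->data(), user_data, n * sizeof(T));
  return scratch->data();
}

int main() {
  std::vector<float> scratch;
  const float user[4] = {1, 2, 3, 4};
  const float* op_ready =
      MaybeReorder(user, 4, Layout::kNChw, Layout::kBlocked16c, &scratch);
  return op_ready == user ? 1 : 0;  // reordered copy expected here
}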
// For backward filter, convert diff_filter back to Tensorflow layout
// Here we prepare to reorder op memory back to user memory
DCHECK(!diff_filter_mkl_shape.IsMklTensor());
auto diff_filter_format = GetOutputFormat(MKL_FMT_TAG);
auto diff_filter_md =
memory::desc(diff_filter_dims, MklDnnType<T>(), diff_filter_format);

// Convert diff_filter (output) back to TF layout if needed
// (i.e. reorder op memory back to user memory)
MklDnnData<T> diff_filter(&cpu_engine_);
bool diff_filter_reorder_required = false;
T* diff_filter_data = nullptr;
if (GetOutputFormat(tf_fmt) !=
conv_bwd_filter->GetDiffFilterMemoryFormat()) {
// Allocate diff filter tensor as Tensorflow layout
diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt),
if (IS_DIFF_FILTER_REORDER_NEEDED(diff_filter_md, diff_filter_format,
bwd_filter_pd, conv_bwd_filter)) {
// Allocate diff_filter tensor as Tensorflow layout.
diff_filter.SetUsrMem(diff_filter_dims, diff_filter_format,
diff_filter_tensor);
diff_filter_reorder_required = true;
diff_filter.PrepareReorderToUserMemIfReq(
bwd_filter_pd->diff_weights_primitive_desc());
bwd_filter_pd->PRIMITIVE_DESC_DIFF_WEIGHTS);
diff_filter_data =
static_cast<T*>(diff_filter.GetOpMem().get_data_handle());
} else {
@@ -577,7 +618,7 @@ class MklConvCustomBackpropFilterOp
const_cast<T*>(diff_filter_tensor->flat<T>().data()));
}

// Execute convolution filter bwd
// Execute convolution backward filter.
if (bias_enabled) {
T* diff_bias_data =
static_cast<T*>(const_cast<T*>(diff_bias_tensor->flat<T>().data()));
@@ -587,12 +628,12 @@ class MklConvCustomBackpropFilterOp
conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
}

// Reorder diff_filter back to Tensorflow layout if necessary
// Reorder diff_filter back to Tensorflow layout if necessary.
if (diff_filter_reorder_required) {
diff_filter.InsertReorderToUserMem();
}

// delete primitive since it is not cached.
// Delete primitive since it is not cached.
if (do_not_cache) delete conv_bwd_filter;
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
@@ -605,13 +646,12 @@ class MklConvCustomBackpropFilterOp
}

private:
const int kInputIndex_Filter = 1;
const int kInputIndex_InputSizes = 0;
const int kInputIdx = 0, kFilterIdx = 1, kDiffDstIdx = 2;
const int kDilationH = 0, kDilationW = 1;
engine cpu_engine_ = engine(engine::cpu, 0);

// Validate input shapes.
// Function asserts that input shapes are valid.
engine cpu_engine_ = engine(ENGINE_CPU, 0);

// Assert that input shapes are valid.
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
@@ -656,20 +696,16 @@ class MklConvCustomBackpropFilterOp

// Output layout is Tensorflow's filter layout
// Conv2D: HWIO; Conv3D: DHWIO; Depthwise Conv: HWIGO
memory::format GetOutputFormat(const memory::format data_format) {
return is_depthwise
? memory::format::hwigo
: ((this->strides_.size() == 4) ? memory::format::hwio
: memory::format::dhwio);
MEMORY_FORMAT GetOutputFormat(const MEMORY_FORMAT data_format) {
return is_depthwise ? MEMORY_FORMAT::hwigo
: ((this->strides_.size() == 4) ? MEMORY_FORMAT::hwio
: MEMORY_FORMAT::dhwio);
}

// Allocate output tensor.
void AllocateOutputTensor(
OpKernelContext* context,
const convolution_backward_weights::primitive_desc& conv_pd,
const memory::dims& output_dims_mkl_order,
memory::format output_tf_format, Tensor** output_tensor) {
CHECK_NOTNULL(output_tensor);
void AllocateOutputTensor(OpKernelContext* context,
const memory::dims& output_dims_mkl_order,
Tensor** output_tensor) {
DCHECK(output_tensor != nullptr);

// For BackpropFilter, we convert the output tensor back in Tensorflow
// layout. Because typically, BackpropFilter is the last operator in the
@@ -689,11 +725,10 @@ class MklConvCustomBackpropFilterOp
output_mkl_shape);
}

// Allocate tensor for bias grad
void AllocateBiasGradTensor(OpKernelContext* context,
const TensorShape& bias_grad_shape,
Tensor** bias_grad_tensor) {
CHECK_NOTNULL(bias_grad_tensor);
DCHECK(bias_grad_tensor);

MklDnnShape bias_grad_mkl_shape;
bias_grad_mkl_shape.SetMklTensor(false);
@@ -742,6 +777,7 @@ class MklConvCustomBackpropFilterOp

TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_FILTER_KERNELS);

#undef REGISTER_MKL_FILTER_KERNELS

}  // namespace tensorflow
@@ -21,6 +21,7 @@ limitations under the License.

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#include <algorithm>
#include <vector>

@@ -39,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -50,9 +52,11 @@ using mkldnn::prop_kind;
using mkldnn::stream;

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

/// utility classes enabling primitive reuse for backward conv ops.
using ConvBwdDataDesc = mkldnn::convolution_backward_data::desc;
using ConvBwdDataPd = mkldnn::convolution_backward_data::primitive_desc;

// Utility classes for enabling primitive reuse for conv bwd input.
struct MklConvBwdInputParams {
memory::dims diff_src_dims;
memory::dims filter_dims;
@@ -82,20 +86,21 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
public:
explicit MklConvBwdInputPrimitive(
const MklConvBwdInputParams& convBwdInputDims)
: cpu_engine_(engine::cpu, 0) {
context_.bwd_input_stream.reset(new stream(stream::kind::eager));
: cpu_engine_(ENGINE_CPU, 0) {
context_.bwd_input_stream.reset(new CPU_STREAM(cpu_engine_));

// create conv primitive
// Create conv bwd input primitive
if (context_.conv_bwd_input == nullptr) {
Setup(convBwdInputDims);
}
}

~MklConvBwdInputPrimitive() {}

// Convolution backward filter (weights)
// diff_src_data: output data buffer of diff_src
// filter_data: input data buffer of filter (weights)
// diff_dst_data: input data buffer of dst
// Convolution backward input (data) execution.
// diff_src_data: output data buffer for diff_src
// filter_data: input data buffer for filter (weights)
// diff_dst_data: input data buffer for dst
// Bias does not matter here
void Execute(const T* diff_src_data, const T* filter_data,
const T* diff_dst_data) {
@@ -106,60 +111,75 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
context_.diff_dst_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_dst_data)));

#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_input_primitives, context_.bwd_input_stream,
context_.bwd_input_primitives_args);
#else
context_.bwd_input_stream->submit(context_.bwd_input_primitives);
#endif  // ENABLE_MKLDNN_V1

// set back data handle
// Set data handle back to DummyData.
context_.diff_src_mem->set_data_handle(DummyData);
context_.filter_mem->set_data_handle(DummyData);
context_.diff_dst_mem->set_data_handle(DummyData);
return;
}

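Under ENABLE_MKLDNN_V1 the execution path changes shape: instead of submitting a pipeline to an eager stream, each primitive is executed with an argument map (MKLDNN_ARG_* -> memory). The exact signature of TensorFlow's execute_primitives helper is not shown in this diff, but a helper of this kind plausibly iterates primitives and their paired argument maps; a sketch with stand-in types, not mkldnn's:

#include <cstddef>
#include <unordered_map>
#include <vector>

struct Memory {};
struct Prim {
  void execute(const std::unordered_map<int, Memory>& /*args*/) const {
    // dispatch to the compiled kernel with the bound buffers
  }
};

void ExecutePrimitives(
    const std::vector<Prim>& prims,
    const std::vector<std::unordered_map<int, Memory>>& args) {
  for (std::size_t i = 0; i < prims.size(); ++i) prims[i].execute(args[i]);
}

int main() {
  std::vector<Prim> prims(1);
  std::vector<std::unordered_map<int, Memory>> args(1);
  ExecutePrimitives(prims, args);
}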
#ifndef ENABLE_MKLDNN_V1
memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }

memory::format GetDiffDstMemoryFormat() const {
return context_.diff_dst_fmt;
}
#endif  // !ENABLE_MKLDNN_V1

std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
GetPrimitiveDesc() const {
std::shared_ptr<ConvBwdDataPd> GetPrimitiveDesc() const {
return context_.bwd_input_pd;
}

private:
// Primitive reuse context for Conv Bwd Input op
// Primitive reuse context for conv bwd input.
struct ConvBwdInputContext {
// expected memory format for this primitive instance
#ifndef ENABLE_MKLDNN_V1
// Expected memory format for this primitive instance.
memory::format filter_fmt;
memory::format diff_dst_fmt;
#endif

// MKLDNN memory
// MKL-DNN memory.
std::shared_ptr<mkldnn::memory> diff_src_mem;
std::shared_ptr<mkldnn::memory> filter_mem;
std::shared_ptr<mkldnn::memory> diff_dst_mem;

// convolution primitive
std::shared_ptr<mkldnn::convolution_backward_data::primitive_desc>
bwd_input_pd;
// Conv backward input primitive descriptor and descriptor.
std::shared_ptr<ConvBwdDataPd> bwd_input_pd;
std::shared_ptr<ConvBwdDataDesc> bwd_input_desc;

// Primitive descriptor and descriptor for conv fwd
std::shared_ptr<ConvFwdPd> fwd_pd;
std::shared_ptr<ConvFwdDesc> fwd_desc;

// Conv bwd input primitive.
std::shared_ptr<mkldnn::primitive> conv_bwd_input;

// desc & prmitive desc
std::shared_ptr<mkldnn::convolution_backward_data::desc> bwd_input_desc;
std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwd_pd;

// memory desc: forward & backward can share same memory::desc
// Memory descriptors: forward & backward share the same descriptors.
std::shared_ptr<memory::desc> diff_src_md;
std::shared_ptr<memory::desc> filter_md;
std::shared_ptr<memory::desc> diff_dst_md;

// MKL pipeline
// MKL-DNN pipeline for executing primitives.
std::shared_ptr<mkldnn::stream> bwd_input_stream;
std::vector<mkldnn::primitive> bwd_input_primitives;

#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> bwd_input_primitives_args;
#endif  // ENABLE_MKLDNN_V1

ConvBwdInputContext()
: filter_fmt(memory::format::any),
:
#ifndef ENABLE_MKLDNN_V1
filter_fmt(memory::format::any),
diff_dst_fmt(memory::format::any),
#endif
diff_src_mem(nullptr),
filter_mem(nullptr),
diff_dst_mem(nullptr),
@@ -171,49 +191,53 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
diff_src_md(nullptr),
filter_md(nullptr),
diff_dst_md(nullptr),
bwd_input_stream(nullptr) {}
bwd_input_stream(nullptr) {
}
};

void Setup(const MklConvBwdInputParams& convBwdInputDims) {
// create memory descriptors for convolution data w/ no specified format
context_.diff_src_md.reset(
new memory::desc({convBwdInputDims.diff_src_dims}, MklDnnType<T>(),
memory::format::any));
// Create memory descriptors for conv bwd input without any specified
// format so that MKL-DNN can pick an appropriate one depending on the
// input parameters.
context_.diff_src_md.reset(new memory::desc(
{convBwdInputDims.diff_src_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
context_.filter_md.reset(new memory::desc(
{convBwdInputDims.filter_dims}, MklDnnType<T>(), memory::format::any));
context_.diff_dst_md.reset(
new memory::desc({convBwdInputDims.diff_dst_dims}, MklDnnType<T>(),
memory::format::any));
{convBwdInputDims.filter_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));
context_.diff_dst_md.reset(new memory::desc(
{convBwdInputDims.diff_dst_dims}, MklDnnType<T>(), MEMORY_FORMAT::any));

// create convolution primitives
context_.bwd_input_desc.reset(new convolution_backward_data::desc(
convolution_direct, *context_.diff_src_md, *context_.filter_md,
*context_.diff_dst_md, convBwdInputDims.strides,
convBwdInputDims.dilations, convBwdInputDims.padding_left,
convBwdInputDims.padding_right, convBwdInputDims.padding));

context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, convolution_direct, *context_.diff_src_md,
// Create descriptors for both conv fwd and conv bwd input.
context_.bwd_input_desc.reset(new ConvBwdDataDesc(
ALGORITHM::convolution_direct, *context_.diff_src_md,
*context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides,
convBwdInputDims.dilations, convBwdInputDims.padding_left,
convBwdInputDims.padding_right, convBwdInputDims.padding));

context_.fwd_pd.reset(new convolution_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
context_.fwd_desc.reset(new ConvFwdDesc(
prop_kind::forward, ALGORITHM::convolution_direct,
*context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md,
convBwdInputDims.strides, convBwdInputDims.dilations,
convBwdInputDims.padding_left, convBwdInputDims.padding_right,
convBwdInputDims.padding));

// create backward conv prim desc
context_.bwd_input_pd.reset(new convolution_backward_data::primitive_desc(
// Create primitive descriptors for conv fwd and conv bwd input.
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
context_.bwd_input_pd.reset(new ConvBwdDataPd(
*context_.bwd_input_desc, cpu_engine_, *context_.fwd_pd));

// create memory primitive based on dummy data
context_.diff_src_mem.reset(new memory(
context_.bwd_input_pd.get()->diff_src_primitive_desc(), DummyData));
context_.filter_mem.reset(new memory(
context_.bwd_input_pd.get()->weights_primitive_desc(), DummyData));
context_.diff_dst_mem.reset(new memory(
context_.bwd_input_pd.get()->diff_dst_primitive_desc(), DummyData));
// Create memory using dummy data.
context_.diff_src_mem.reset(new MEMORY_CONSTRUCTOR(
context_.bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine_,
DummyData));
context_.filter_mem.reset(new MEMORY_CONSTRUCTOR(
context_.bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_,
DummyData));
context_.diff_dst_mem.reset(new MEMORY_CONSTRUCTOR(
context_.bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, cpu_engine_,
DummyData));

// store the expected memory format
#ifndef ENABLE_MKLDNN_V1
// Store the expected memory format.
context_.filter_fmt =
static_cast<memory::format>(context_.bwd_input_pd.get()
->weights_primitive_desc()
@@ -224,11 +248,22 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
->diff_dst_primitive_desc()
.desc()
.data.format);
#endif  // !ENABLE_MKLDNN_V1

// create convolution primitive and add it to net
// Create conv bwd input primitive and add it to the net
#ifdef ENABLE_MKLDNN_V1
context_.conv_bwd_input.reset(
new convolution_backward_data(*context_.bwd_input_pd));
context_.bwd_input_primitives_args.push_back(
{{MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{ MKLDNN_ARG_DIFF_SRC,
*context_.diff_src_mem }});
#else
context_.conv_bwd_input.reset(new convolution_backward_data(
*context_.bwd_input_pd, *context_.diff_dst_mem, *context_.filter_mem,
*context_.diff_src_mem));
#endif  // ENABLE_MKLDNN_V1

context_.bwd_input_primitives.push_back(*context_.conv_bwd_input);
}
@@ -248,10 +283,10 @@ class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
const MklConvBwdInputParams& convBwdInputDims, bool do_not_cache) {
MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;

if (do_not_cache) { /* Always allocate primitive */
if (do_not_cache) {  // Always allocate primitive.
conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
} else {
// look into the pool for reusable primitive
// look into the pool for reusable primitive.
conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
convBwdInputDims));
@@ -308,14 +343,7 @@ class MklConvCustomBackpropInputOp

void Compute(OpKernelContext* context) {
try {
MklDnnData<T> filter(&cpu_engine);
MklDnnData<T> diff_dst(&cpu_engine);

// This flag indicate Conv2D or Conv3D
bool is_conv2d = (this->strides_.size() == 4);

// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
// Input tensors.
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx);
@@ -350,9 +378,9 @@ class MklConvCustomBackpropInputOp
AllocateOutputSetMklShape(context, kOutputIdx, &diff_src_tensor,
diff_src_tf_shape, diff_src_mkl_shape,
eager_mode);
CHECK_NOTNULL(diff_src_tensor);
DCHECK(diff_src_tensor != nullptr);

// if output tensor has more than 0 elements, we need to 0 them out.
// If output tensor has more than 0 elements, we need to 0 them out.
auto diff_src_data = diff_src_tensor->flat<T>().data();
for (size_t i = 0; i < diff_src_tf_shape.num_elements(); ++i) {
diff_src_data[i] = static_cast<T>(0);
@@ -360,28 +388,36 @@ class MklConvCustomBackpropInputOp
return;
}

// By default, all dims are in MKL order. Only dims in TF order
// are those with postfix tf_order.
// By default, all dims are in MKL order except those that are suffixed
// with `tf_order`.
memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
memory::dims padding_left, padding_right, dilations, strides;
memory::dims fwd_output_dims, fwd_output_dims_tf_order;

// Get forward convolution parameters.
MklDnnConvUtil conv_utl(context, this->strides_, this->padding_,
this->data_format_, this->dilations_);
conv_utl.GetConvFwdSizesInMklOrder(
// Get conv fwd parameters.
MklDnnConvUtil conv_util(context, this->strides_, this->padding_,
this->data_format_, this->dilations_);
conv_util.GetConvFwdSizesInMklOrder(
src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims,
&strides, &dilations, &fwd_output_dims_tf_order, &fwd_output_dims,
&padding_left, &padding_right, false, is_depthwise);
if (!context->status().ok()) return;

// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
bool is_conv2d = (this->strides_.size() == 4);

// Create conv fwd descriptor since conv bwd input API needs it.
// For that, we first need to create input, filter and output memory
// descriptors.
auto tf_fmt = is_conv2d
? TFDataFormatToMklDnnDataFormat(this->data_format_)
: TFDataFormatToMklDnn3DDataFormat(this->data_format_);

#ifdef ENABLE_MKLDNN_V1
auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
errors::InvalidArgument("Invalid data format"));
#endif  // ENABLE_MKLDNN_V1

// If filter is in MKL layout, then simply grab filter layout;
// otherwise, construct filter in TF layout.
// For TF layout, filter is in HWIO format.
@@ -389,42 +425,47 @@ class MklConvCustomBackpropInputOp
filter_mkl_shape.IsMklTensor()
? filter_mkl_shape.GetMklLayout()
: memory::desc(fwd_filter_dims, MklDnnType<T>(),
is_depthwise
? memory::hwigo
: (is_conv2d ? memory::format::hwio
: memory::format::dhwio));
is_depthwise ? MEMORY_FORMAT::hwigo
: (is_conv2d ? MEMORY_FORMAT::hwio
: MEMORY_FORMAT::dhwio));

conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
conv_util.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
if (!context->status().ok()) return;

auto diff_dst_md =
diff_dst_mkl_shape.IsMklTensor()
? diff_dst_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims, MklDnnType<T>(), tf_fmt);
for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
: memory::desc(diff_dst_dims, MklDnnType<T>(), MKL_FMT_TAG);

// The default dilation factor for each dimension is 1 in TF and
// 0 in MKL-DNN.
for (int i = 0; i < dilations.size(); ++i) --dilations[i];

MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
MklConvBwdInputParams convBwdInputDims(
fwd_src_dims, fwd_filter_dims, diff_dst_dims, strides, dilations,
padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));

// We don't cache those primitves if the env variable
// We don't cache those primitives if the environment variable
// TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitve descriptor
// includes potentialy large buffers. MKL DNN allocates buffers
// includes potentialy large buffers. MKL-DNN allocates buffers
// in the following cases
// 1. Legacy CPU without AVX512/AVX2, or
// 2. 1x1 convolution with stride != 1
bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
(MklPrimitiveFactory<T>::IsLegacyPlatform() ||
IsConv1x1StrideNot1(fwd_filter_dims, strides));
conv_bwd_input = MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims,
do_not_cache);
auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();

// allocate output tensor
auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc();
MklConvBwdInputPrimitive<T>* conv_bwd_input =
MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims,
do_not_cache);

auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
auto diff_src_pd = bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_SRC;
auto bwd_diff_src_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
auto bwd_diff_src_format = GetOutputFormat(tf_fmt);

// Allocate output tensor.
MklDnnShape diff_src_mkl_shape;
diff_src_mkl_shape.SetMklTensor(true);
diff_src_mkl_shape.SetMklLayout(&diff_src_pd);
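The do_not_cache policy in the hunk above trades memory for speed: caching is skipped only when the user opted in via TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE and the primitive is one of the known large-buffer cases. Condensed into a single predicate (all inputs assumed computed elsewhere, as in the source):

bool ShouldSkipCache(bool mem_opt_env_set, bool legacy_platform,
                     bool conv1x1_stride_not_1) {
  // Mirrors: IsPrimitiveMemOptEnabled() &&
  //          (IsLegacyPlatform() || IsConv1x1StrideNot1(...))
  return mem_opt_env_set && (legacy_platform || conv1x1_stride_not_1);
}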
@@ -443,12 +484,14 @@ class MklConvCustomBackpropInputOp
T* diff_src_data =
static_cast<T*>(const_cast<T*>(diff_src_tensor->flat<T>().data()));

// check if filter and diff_dst need reorder
// Check if filter and diff_dst need to be reordered.
T* filter_data = nullptr;
if (fwd_filter_md.data.format !=
conv_bwd_input->GetFilterMemoryFormat()) {
MklDnnData<T> filter(&cpu_engine_);
if (IS_FILTER_REORDER_NEEDED(fwd_filter_md, bwd_input_pd,
conv_bwd_input)) {
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc());
filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_));
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
} else {
filter_data =
@@ -456,16 +499,19 @@ class MklConvCustomBackpropInputOp
}

T* diff_dst_data = nullptr;
if (diff_dst_md.data.format != conv_bwd_input->GetDiffDstMemoryFormat()) {
MklDnnData<T> diff_dst(&cpu_engine_);
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_input_pd,
conv_bwd_input)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc());
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
diff_dst_data = static_cast<T*>(diff_dst.GetOpMem().get_data_handle());
} else {
diff_dst_data =
static_cast<T*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
}

// execute convolution input bwd
// Execute conv bwd input primitive.
if (!eager_mode) {
conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
} else {
@@ -475,18 +521,20 @@ class MklConvCustomBackpropInputOp
static_cast<T*>(const_cast<T*>(tmp_tensor.flat<T>().data()));
conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data);
auto output_tf_md = diff_src_mkl_shape.GetTfLayout();
auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
mkldnn::reorder::primitive_desc reorder_pd =
mkldnn::reorder::primitive_desc(diff_src_pd, output_tf_pd);
std::vector<mkldnn::primitive> net;
memory* tmp_data_mem = new memory(diff_src_pd, tmp_data);
memory* dst_data_mem = new memory(output_tf_pd, diff_src_data);
net.push_back(
mkldnn::reorder(reorder_pd, *tmp_data_mem, *dst_data_mem));
stream(stream::kind::eager).submit(net).wait();
#ifndef ENABLE_MKLDNN_V1
auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_);
#endif
ReorderPd reorder_pd =
REORDER_PD_CONSTRUCTOR(diff_src_pd, OUTPUT_TF_MD, cpu_engine_);
memory* tmp_data_mem =
new MEMORY_CONSTRUCTOR(diff_src_pd, cpu_engine_, tmp_data);
memory* dst_data_mem =
new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, diff_src_data);
CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
cpu_engine_);
}

// delete primitive since it is not cached.
// Delete primitive since it is not cached.
if (do_not_cache) {
delete conv_bwd_input;
}
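The eager-mode branch above computes diff_src into a temporary tensor in the primitive's preferred (MKL-blocked) layout and then reorders it into the user's TensorFlow-layout buffer. The hand-off reduces to this shape, with plain buffers standing in for mkldnn memory objects and a copy standing in for the reorder primitive:

#include <algorithm>
#include <vector>

void EagerStyleReorder(const std::vector<float>& tmp_blocked,
                       std::vector<float>* user_tf_layout) {
  user_tf_layout->resize(tmp_blocked.size());
  // Real code builds reorder_pd(diff_src_pd -> output_tf_md) and executes
  // it on a stream; the element permutation is elided here.
  std::copy(tmp_blocked.begin(), tmp_blocked.end(), user_tf_layout->begin());
}

int main() {
  std::vector<float> tmp = {1, 2, 3}, user;
  EagerStyleReorder(tmp, &user);
}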
@@ -501,12 +549,12 @@ class MklConvCustomBackpropInputOp
}

private:
const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0;
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const int kDilationH = 0, kDilationW = 1;
engine cpu_engine = engine(engine::cpu, 0);

// Validate input shapes.
// Function asserts that input shapes are valid.
engine cpu_engine_ = engine(ENGINE_CPU, 0);

// Assert that input shapes are valid.
void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
@@ -532,7 +580,7 @@ class MklConvCustomBackpropInputOp
// Get TensorFlow shape of filter tensor.
TensorShape MakeFilterTfShape(OpKernelContext* context,
const Tensor& filter_tensor) {
return GetTfShape(context, kInputIndex_Filter, eager_mode);
return GetTfShape(context, kFilterIdx, eager_mode);
}

// Get the Tensorflow shape of Output (diff_src),
@@ -543,30 +591,29 @@ class MklConvCustomBackpropInputOp
return input_shape;
}

// Get the Tensorflow shape of Output (diff_src),
// which is same as shape of Conv 'input'.
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
return fwd_input_dims;
}

// Output layout is Tensorflow's layout in data format order.
memory::format GetOutputFormat(const memory::format data_format) {
MKL_TENSOR_FORMAT GetOutputFormat(const MKL_TENSOR_FORMAT data_format) {
return data_format;
}

// Allocate output tensor.
void AllocateOutputTensor(
OpKernelContext* context,
const convolution_backward_data::primitive_desc& conv_pd,
const memory::dims& output_dims_mkl_order,
memory::format output_tf_format, Tensor** output_tensor) {
CHECK_NOTNULL(output_tensor);
// TODO(bhavanis): Move this function to mkl_util.h since it is common to
// both the forward and backward implementations
void AllocateOutputTensor(OpKernelContext* context,
const ConvBwdDataPd& conv_pd,
const memory::dims& output_dims_mkl_order,
MKL_TENSOR_FORMAT output_tf_format,
Tensor** output_tensor) {
DCHECK(output_tensor != nullptr);

// Output primitive descriptor for backward data is diff_src.
auto dst_pd = conv_pd.diff_src_primitive_desc();
auto dst_pd = conv_pd.PRIMITIVE_DESC_DIFF_SRC;

// Allocate shape of Mkl tensor.
// Allocate shape of MKL tensor.
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(true);
output_mkl_shape.SetMklLayout(&dst_pd);
@@ -608,8 +655,10 @@ class MklConvCustomBackpropInputOp
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklConvCustomBackpropInputOp<CPUDevice, T, true, false>);

TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
TF_CALL_bfloat16(REGISTER_MKL_CPU_KERNELS);

#undef REGISTER_MKL_CPU_KERNELS

}  // namespace tensorflow
@@ -55,6 +55,7 @@ namespace tensorflow {
#define MKLDNN_SIZE_DTYPE int
#endif  // ENABLE_MKLDNN_V1

using ConvFwdDesc = mkldnn::convolution_forward::desc;
using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;

class MklDnnConvUtil {