Merge pull request #30549 from Intel-tensorflow:mkldnn-1.0-conv2d-fwd

PiperOrigin-RevId: 260989299
This commit is contained in:
TensorFlower Gardener 2019-07-31 14:28:28 -07:00
commit 644e458325
4 changed files with 477 additions and 165 deletions

View File

@@ -353,9 +353,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.mul = "Mul";
csinfo_.squared_difference = "SquaredDifference";
csinfo_.sub = "Sub";
// End - element-wise ops. See note above.
// NOTE: names are alphabetically sorted.
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
@@ -389,10 +390,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
{csinfo_.conjugate_transpose,
mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
#endif // !ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.conv2d,
mkl_op_registry::GetMklOpName(csinfo_.conv2d),
CopyAttrsConvCheckConstFilter, AlwaysRewrite,
kRewriteForLayoutPropagation});
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
CopyAttrsConvCheckConstFilter, AlwaysRewrite,
kRewriteForLayoutPropagation});
@@ -641,6 +644,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back(
{csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#endif // !ENABLE_MKLDNN_V1
// Disable these two MKL operators for now due to some test failures caused
// by these two ops
/*
@@ -653,6 +657,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
*/
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back(
{csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
@@ -753,6 +758,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// CheckForMklOp
FuseConv3D,
CopyAttrsConv});
#endif // !ENABLE_MKLDNN_V1
}
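For orientation: each rinfo_ entry above pairs a TF op with its MKL counterpart, an attribute-copying function, a rewrite predicate, and a rewrite cause. A hypothetical sketch of an entry's shape, with field names and types inferred from the initializers in this file rather than taken from the real struct definition:

// Hypothetical reconstruction for illustration only; the actual struct in
// this pass may differ.
struct RewriteInfoSketch {
  string name;       // original TF op name, e.g. csinfo_.conv2d ("Conv2D")
  string new_name;   // MKL op name, e.g. mkl_op_registry::GetMklOpName(name)
  void (*copy_attrs)(const Node*, NodeBuilder*);  // e.g. CopyAttrsAll
  bool (*rewrite_rule)(const Node*);              // e.g. AlwaysRewrite
  int rewrite_cause;  // e.g. kRewriteForLayoutPropagation
};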
// Standard interface to run pass

View File

@@ -47,13 +47,96 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
using mkldnn::convolution_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
namespace tensorflow {
#ifdef ENABLE_MKLDNN_V1
#define ADD_MD add_md
#define ALGORITHM mkldnn::algorithm
#define ALGORITHM_UNDEF ALGORITHM::undef
#define CPU_STREAM(engine) stream(engine)
#define DATA_WITH_ENGINE(data, engine) data, engine
#define DST_MD dst_md
#define ENGINE_CPU engine::kind::cpu
#define GET_DESC get_desc()
#define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \
{ {dims}, MklDnnType<type>(), memory::format_tag::fm }
#define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd->src_desc()
#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd->weights_desc()
#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op_primitive) \
GET_WEIGHTS_DESC_FROM_OP_PD(op_pd)
#define IS_FILTER_REORDER_NEEDED(filter_md, op_pd, op_primitive) \
filter_md != op_pd->weights_desc()
#define IS_SRC_REORDER_NEEDED(src_md, op_pd, op_primitive) \
src_md != op_pd->src_desc()
#define MEMORY_CONSTRUCTOR(mem_desc, engine, data) \
memory(mem_desc, engine, data)
#define MEMORY_CONSTRUCTOR_USING_MEM_PD(dims, type, fm, engine, data) \
memory(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine, data)
#define MEMORY_CONSTRUCTOR_WITHOUT_DATA(mem_desc, engine) \
memory(mem_desc, engine)
#define MEMORY_DESC memory::desc
#define MEMORY_FORMAT mkldnn::memory::format_tag
#define MEMORY_PD_CONSTRUCTOR(dims, type, fm, engine) \
memory::desc({dims}, MklDnnType<type>(), memory::format_tag::fm)
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
#define MKL_TENSOR_FORMAT MklTensorFormat
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
#define MKL_TENSOR_FORMAT_IN_C MKL_TENSOR_FORMAT
#define PRIMITIVE_DESC_BIAS bias_desc()
#define PRIMITIVE_DESC_DST dst_desc()
#define PRIMITIVE_DESC_SRC src_desc()
#define PRIMITIVE_DESC_WEIGHTS weights_desc()
#define REORDER_PD_CONSTRUCTOR(src_md, dst_md, engine) \
mkldnn::reorder::primitive_desc(engine, src_md, engine, dst_md)
#define REORDER_PD_CONSTRUCTOR_WITH_ATTR(src_md, dst_md, engine, prim_attr) \
mkldnn::reorder::primitive_desc(engine, src_md, engine, dst_md, prim_attr)
#define SUMMAND_MD summand_md
#else
#define ADD_MD add_pd
#define ALGORITHM mkldnn
#define ALGORITHM_UNDEF ALGORITHM::algorithm_undef
#define CPU_STREAM(engine) stream(stream::kind::eager)
#define DATA_WITH_ENGINE(data, engine) data
#define DST_MD dst_pd
#define ENGINE_CPU engine::cpu
#define GET_DESC get_primitive_desc()
#define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \
{ {dims}, MklDnnType<type>(), memory::format::fm }
#define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd.get()->src_primitive_desc()
#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd.get()->weights_primitive_desc()
#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op_primitive) \
op_primitive->GetFilterMemoryFormat()
#define IS_FILTER_REORDER_NEEDED(filter_md, op_pd, op_primitive) \
filter_md.data.format != op_primitive->GetFilterMemoryFormat()
#define IS_SRC_REORDER_NEEDED(src_md, op_pd, op_primitive) \
src_md.data.format != op_primitive->GetSrcMemoryFormat()
#define MEMORY_CONSTRUCTOR(mem_pd, engine, data) memory(mem_pd, data)
#define MEMORY_CONSTRUCTOR_USING_MEM_PD(dims, type, fm, engine, data) \
memory({GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine}, data)
#define MEMORY_CONSTRUCTOR_WITHOUT_DATA(mem_pd, engine) memory(mem_pd)
#define MEMORY_DESC memory::format
#define MEMORY_FORMAT mkldnn::memory::format
#define MEMORY_PD_CONSTRUCTOR(dims, type, fm, engine) \
memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
#define MKL_TENSOR_FORMAT memory::format
#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked
#define MKL_TENSOR_FORMAT_IN_C mkldnn_memory_format_t
#define PRIMITIVE_DESC_BIAS bias_primitive_desc()
#define PRIMITIVE_DESC_DST dst_primitive_desc()
#define PRIMITIVE_DESC_SRC src_primitive_desc()
#define PRIMITIVE_DESC_WEIGHTS weights_primitive_desc()
#define REORDER_PD_CONSTRUCTOR(src_pd, dst_pd, engine) \
mkldnn::reorder::primitive_desc(src_pd, dst_pd)
#define REORDER_PD_CONSTRUCTOR_WITH_ATTR(src_pd, dst_pd, engine, prim_attr) \
mkldnn::reorder::primitive_desc(src_pd, dst_pd, prim_attr)
#define SUMMAND_MD summand_pd
#endif // ENABLE_MKLDNN_V1
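The macro block above is the heart of the port: every call site is written once and compiles against either API, because each macro pair resolves to the v0.x or the v1.x spelling. A minimal sketch of one expansion, taken directly from the definitions above:

// With ENABLE_MKLDNN_V1 defined:
//   new MEMORY_CONSTRUCTOR(md, cpu_engine_, DummyData)
//     -> new memory(md, cpu_engine_, DummyData)  // v1.x: engine is explicit
// Without ENABLE_MKLDNN_V1:
//   new MEMORY_CONSTRUCTOR(pd, cpu_engine_, DummyData)
//     -> new memory(pd, DummyData)  // v0.x: the primitive_desc already binds the engine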
// This structure aggregates multiple inputs to Conv2DFwd* methods.
struct MklConvFwdParams {
memory::dims src_dims;
@@ -94,9 +177,9 @@ template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
class MklConvFwdPrimitive : public MklPrimitive {
public:
explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
: cpu_engine_(ENGINE_CPU, 0) {
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
// Create convolution primitive
if (context_.conv_fwd == nullptr) {
Setup(convFwdDims);
}
@@ -115,19 +198,30 @@ class MklConvFwdPrimitive : public MklPrimitive {
static_cast<void*>(const_cast<Tinput*>(src_data)));
context_.filter_mem->set_data_handle(
static_cast<void*>(const_cast<Tfilter*>(filter_data)));
if (bias_data != nullptr) {
context_.bias_mem->set_data_handle(
static_cast<void*>(const_cast<Tbias*>(bias_data)));
}
context_.dst_mem->set_data_handle(
static_cast<void*>(const_cast<Toutput*>(dst_data)));
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
context_.fwd_primitives_args.at(i));
}
#else
context_.fwd_stream->submit(context_.fwd_primitives);
#endif // ENABLE_MKLDNN_V1
// After execution, set data handle back
context_.src_mem->set_data_handle(DummyData);
context_.filter_mem->set_data_handle(DummyData);
if (bias_data != nullptr) {
context_.bias_mem->set_data_handle(DummyData);
}
context_.dst_mem->set_data_handle(DummyData);
return;
}
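For readers new to the v1.x API exercised in the #ifdef branch above: a primitive no longer captures its memory objects at construction; execute() takes a stream plus an explicit argument map. A standalone sketch (engine and memory variable names assumed):

mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
mkldnn::stream cpu_stream(eng);
conv_prim.execute(cpu_stream, {{MKLDNN_ARG_SRC, src_mem},
                               {MKLDNN_ARG_WEIGHTS, filter_mem},
                               {MKLDNN_ARG_DST, dst_mem}});
cpu_stream.wait();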
// Convolution forward execute without bias
@@ -136,23 +230,15 @@ class MklConvFwdPrimitive : public MklPrimitive {
// dst_data: output data buffer of dst
void Execute(const Tinput* src_data, const Tfilter* filter_data,
const Toutput* dst_data) {
Execute(src_data, filter_data, nullptr, dst_data);
}
#ifndef ENABLE_MKLDNN_V1
// In MKL-DNN v1.x, memory format tags only provide a partial description
// of the memory layout. Hence, these functions are disabled for v1.x.
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
#endif // !ENABLE_MKLDNN_V1
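Because the enum-style formats are gone in v1.x, reorder decisions compare whole memory descriptors instead; this is what IS_SRC_REORDER_NEEDED expands to. A minimal sketch of the v1.x check:

// v1.x: memory::desc defines operator==/!=, so a reorder is needed whenever
// the user-supplied descriptor differs from the one the primitive expects.
bool NeedsSrcReorder(const mkldnn::memory::desc& user_src_md,
                     const std::shared_ptr<ConvFwdPd>& conv_pd) {
  return user_src_md != conv_pd->src_desc();
}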
std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
return context_.fwd_pd;
@@ -161,17 +247,19 @@ class MklConvFwdPrimitive : public MklPrimitive {
private:
// Primitive reuse context for Conv2D Fwd op
struct ConvFwdContext {
#ifndef ENABLE_MKLDNN_V1
// Expected memory format for this primitive instance
memory::format src_fmt;
memory::format filter_fmt;
#endif // !ENABLE_MKLDNN_V1
// MKL-DNN memory
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> filter_mem;
std::shared_ptr<mkldnn::memory> bias_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
// Desc & primitive desc
std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;
// Memory desc
@@ -187,9 +275,16 @@ class MklConvFwdPrimitive : public MklPrimitive {
std::shared_ptr<mkldnn::stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
#endif // ENABLE_MKLDNN_V1
ConvFwdContext()
:
#ifndef ENABLE_MKLDNN_V1
src_fmt(memory::format::any),
filter_fmt(memory::format::any),
#endif // !ENABLE_MKLDNN_V1
src_mem(nullptr),
filter_mem(nullptr),
bias_mem(nullptr),
@@ -200,34 +295,35 @@ class MklConvFwdPrimitive : public MklPrimitive {
bias_md(nullptr),
fwd_pd(nullptr),
conv_fwd(nullptr),
fwd_stream(nullptr) {}
};
void Setup(const MklConvFwdParams& convFwdDims) {
// Create memory descriptors for convolution data w/ no specified format
context_.src_md.reset(new memory::desc(
{convFwdDims.src_dims}, MklDnnType<Tinput>(), MEMORY_FORMAT::any));
context_.filter_md.reset(new memory::desc(
{convFwdDims.filter_dims}, MklDnnType<Tfilter>(), MEMORY_FORMAT::any));
context_.dst_md.reset(new memory::desc(
{convFwdDims.dst_dims}, MklDnnType<Toutput>(), MEMORY_FORMAT::any));
if (!convFwdDims.bias_dims.empty())
context_.bias_md.reset(new memory::desc(
{convFwdDims.bias_dims}, MklDnnType<Tbias>(), MEMORY_FORMAT::any));
// Create a convolution descriptor
if (!convFwdDims.bias_dims.empty()) {
context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
*context_.filter_md, *context_.bias_md, *context_.dst_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
convFwdDims.padding_right, padding_kind::zero));
} else {
context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
*context_.filter_md, *context_.dst_md, convFwdDims.strides,
convFwdDims.dilations, convFwdDims.padding_left,
convFwdDims.padding_right, padding_kind::zero));
@@ -246,7 +342,12 @@ class MklConvFwdPrimitive : public MklPrimitive {
float op_scale = post_op_param.param[0];
float op_alpha = post_op_param.param[1];
float op_beta = post_op_param.param[2];
#ifdef ENABLE_MKLDNN_V1
post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu,
op_alpha,
#else
post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha,
#endif // ENABLE_MKLDNN_V1
op_beta);
} else if (post_op_param.name == "sum") {
DCHECK_EQ(post_op_param.param.size(), 1);
@@ -271,27 +372,44 @@ class MklConvFwdPrimitive : public MklPrimitive {
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
}
#ifndef ENABLE_MKLDNN_V1
// Store the expected memory format
context_.src_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->src_primitive_desc().desc().data.format);
context_.filter_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
#endif // !ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData));
context_.filter_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));
// Create convolution primitive and add it to net
if (!convFwdDims.bias_dims.empty()) {
context_.bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
convFwdDims.bias_dims, Tbias, x, cpu_engine_, DummyData));
#ifdef ENABLE_MKLDNN_V1
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{MKLDNN_ARG_BIAS, *context_.bias_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
} else {
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
}
#else
context_.conv_fwd.reset(new convolution_forward(
*context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
*context_.bias_mem, *context_.dst_mem));
@@ -300,9 +418,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
new convolution_forward(*context_.fwd_pd, *context_.src_mem,
*context_.filter_mem, *context_.dst_mem));
}
#endif // ENABLE_MKLDNN_V1
context_.fwd_primitives.push_back(*context_.conv_fwd);
}
struct ConvFwdContext context_;
@@ -566,6 +683,12 @@ class MklConvOp : public OpKernel {
auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
: TFDataFormatToMklDnn3DDataFormat(data_format_);
#ifdef ENABLE_MKLDNN_V1
auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
// NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
DCHECK_NE(mkl_fmt_tag, memory::format_tag::undef);
#endif // ENABLE_MKLDNN_V1
// If input is in MKL layout, then simply grab the layout; otherwise,
// construct TF layout for input.
// For constructing TF layout for input, although input shape (src_dims)
@@ -573,18 +696,22 @@ class MklConvOp : public OpKernel {
// TF layout depending on the data format:
// Conv2D: NHWC or NCHW
// Conv3D: NDHWC or NCDHW
auto src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
#ifdef ENABLE_MKLDNN_V1
: memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
#else
: memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
#endif // ENABLE_MKLDNN_V1
src.SetUsrMem(src_md, &src_tensor);
// Although the filter shape (filter_dims) is required in MKL-DNN order,
// the layout is TensorFlow's layout: HWIO, or HWIGO for depthwise/group
// convolutions.
auto filter_format = is_conv2d ? (is_depthwise ? MEMORY_FORMAT::hwigo
: MEMORY_FORMAT::hwio)
: MEMORY_FORMAT::dhwio;
DCHECK(!filter_mkl_shape.IsMklTensor());
auto filter_md =
@@ -593,7 +720,7 @@ class MklConvOp : public OpKernel {
: memory::desc(filter_dims, MklDnnType<Tfilter>(), filter_format);
filter.SetUsrMem(filter_md, &filter_tensor);
// MKL-DNN dilations start from 0.
for (int i = 0; i < dilations.size(); ++i) --dilations[i];
// In some cases, primitive descriptor could potentially contain
@@ -643,10 +770,11 @@ class MklConvOp : public OpKernel {
// Check whether src and filter need to be reordered
Tinput* src_data = nullptr;
if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
// Reorder src
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<Tinput*>(
@@ -654,7 +782,7 @@ class MklConvOp : public OpKernel {
}
Tfilter* filter_data = nullptr;
if (IS_FILTER_REORDER_NEEDED(filter_md, conv_fwd_pd, conv_fwd)) {
bool is_filter_cached = false;
// If filter is a constant, we can avoid the conversion of filter from
// Tensorflow format to MKL format by caching the filter when it is
@@ -664,21 +792,26 @@ class MklConvOp : public OpKernel {
if (IsFilterCacheEmpty(context)) {
// Cache filter if it is not already cached.
CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
#ifdef ENABLE_MKLDNN_V1
filter, filter_md, filter_mkl_shape);
#else
filter, filter_md);
#endif // ENABLE_MKLDNN_V1
}
filter_data = GetCachedFilter(
context, GET_WEIGHTS_FORMAT_FROM_OP_PD(conv_fwd_pd, conv_fwd));
is_filter_cached = (filter_data != nullptr);
}
if (!is_filter_cached) {
filter.SetUsrMem(filter_md, &filter_tensor);
if (filter_out_tensor == nullptr) {
filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
} else {
filter.CheckReorderToOpMem(
GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
cpu_engine_));
}
filter_data =
static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
@@ -787,7 +920,7 @@ class MklConvOp : public OpKernel {
// NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
// checking `fuse_biasadd_` flag.
if (fuse_add_) {
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}});
}
if (fuse_activation_) {
params.post_op_params.push_back(
@@ -808,29 +941,35 @@ class MklConvOp : public OpKernel {
virtual void AllocateOutputTensor(OpKernelContext* context,
const ConvFwdPd& conv_prim_desc,
const memory::dims& output_dims_mkl_order,
MKL_TENSOR_FORMAT output_tf_format,
Tensor** output_tensor) {
DCHECK(output_tensor);
#ifdef ENABLE_MKLDNN_V1
auto dst_md = conv_prim_desc.dst_desc();
#else
auto dst_pd = conv_prim_desc.dst_primitive_desc();
auto dst_md = dst_pd.desc();
#endif // ENABLE_MKLDNN_V1
if (!std::is_same<Ttemp_output, Toutput>::value) {
dst_md.data.data_type =
static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
#ifndef ENABLE_MKLDNN_V1
dst_pd = memory::primitive_desc(dst_md, cpu_engine_);
#endif // !ENABLE_MKLDNN_V1
}
// Allocate shape of MKL tensor
MklDnnShape output_mkl_shape;
output_mkl_shape.SetMklTensor(true);
output_mkl_shape.SetMklLayout(&DST_MD);
output_mkl_shape.SetElemType(MklDnnType<Toutput>());
output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
output_dims_mkl_order, output_tf_format);
// Allocate shape of TF tensor
TensorShape output_tf_shape;
output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput)));
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
output_tf_shape, output_mkl_shape);
if (fuse_add_) {
@@ -838,32 +977,40 @@ class MklConvOp : public OpKernel {
MklDnnShape add_mkl_shape;
GetMklShape(context, kInputIndex_Add, &add_mkl_shape);
// Check if reorder is needed
if (add_mkl_shape == output_mkl_shape) {
DCHECK((*output_tensor)->CopyFrom(add_tensor, output_tf_shape));
} else {
#ifdef ENABLE_MKLDNN_V1
auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
    output_mkl_shape.GetTfDataFormat());
DCHECK_NE(output_format_tag, memory::format_tag::undef);
auto add_md =
    add_mkl_shape.IsMklTensor()
        ? add_mkl_shape.GetMklLayout()
        : memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
                       output_format_tag);
#else
auto add_md =
    add_mkl_shape.IsMklTensor()
        ? add_mkl_shape.GetMklLayout()
        : memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
                       output_mkl_shape.GetTfDataFormat());
auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_);
#endif // ENABLE_MKLDNN_V1
void* add_buf = static_cast<void*>(
const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
void* dst_buf =
static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
auto add = new MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
auto dst = new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
auto reorder_desc =
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
CreateAndExecuteReorder(reorder_desc, *add, *dst, this->cpu_engine_);
}
}
}
engine cpu_engine_ = engine(ENGINE_CPU, 0);
private:
std::vector<int32> strides_;
@@ -883,7 +1030,7 @@ class MklConvOp : public OpKernel {
bool fuse_add_ = false;
float relu_up_bound_ = 0.0;
mkldnn::algorithm activation_alg_ = ALGORITHM_UNDEF;
int input_index_pad_ = 2;
@@ -892,15 +1039,27 @@ class MklConvOp : public OpKernel {
const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
const int kDilationH = 0, kDilationW = 1;
MKL_TENSOR_FORMAT_IN_C GetFilterTfDataFormat(
const MklDnnShape* filter_mkl_shape,
const ConvFwdPd& conv_prim_desc) const {
#ifdef ENABLE_MKLDNN_V1
DCHECK(filter_mkl_shape);
return filter_mkl_shape->GetTfDataFormat();
#else
return conv_prim_desc.weights_primitive_desc().desc().data.format;
#endif // ENABLE_MKLDNN_V1
}
// Allocate persistent tensors for cached filter data and
// cached filter memory descriptor (data format)
void AllocatePersistentTensor(OpKernelContext* context,
const ConvFwdPd& conv_prim_desc,
Tensor** filter_tensor,
const MklDnnShape* filter_mkl_shape) {
DCHECK(filter_tensor);
TensorShape filter_tf_shape;
filter_tf_shape.AddDim(
(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS.get_size() / sizeof(Tfilter)));
OP_REQUIRES_OK(context, context->allocate_persistent(
DataTypeToEnum<Tfilter>::value, filter_tf_shape,
&cached_filter_data_ptensor_, filter_tensor));
@@ -908,37 +1067,44 @@ class MklConvOp : public OpKernel {
Tensor* second_tensor = nullptr;
TensorShape filter_mkl_format;
filter_mkl_format.AddDim(
sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) /
sizeof(DT_INT32));
OP_REQUIRES_OK(context, context->allocate_persistent(
DT_INT32, filter_mkl_format,
&cached_filter_md_ptensor_, &second_tensor));
second_tensor->scalar<int32>()() = static_cast<int32>(
GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
}
void AllocatePersistentTensor(OpKernelContext* context,
const ConvFwdPd& conv_prim_desc,
Tensor** filter_tensor) {
AllocatePersistentTensor(context, conv_prim_desc, filter_tensor, nullptr);
}
void AllocateFilterOutputTensor(OpKernelContext* context,
const ConvFwdPd& conv_prim_desc,
const memory::dims& filter_dims_tf_order,
Tensor** filter_tensor) {
DCHECK(filter_tensor);
auto filter_md = conv_prim_desc.PRIMITIVE_DESC_WEIGHTS;
// Allocate shape of MKL tensor
MklDnnShape filter_mkl_shape;
filter_mkl_shape.SetMklTensor(true);
filter_mkl_shape.SetMklLayout(&filter_md);
filter_mkl_shape.SetElemType(MklDnnType<Tfilter>());
// The format of the filter is actually OIhw8i8o, but TF doesn't support
// this format. Just use format::blocked for now because the layout
// is stored in the MKL data.
filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
filter_dims_tf_order,
MKL_TENSOR_FORMAT_BLOCKED);
// Allocate the data space for the filter to propagate as TF tensor.
TensorShape filter_tf_shape;
filter_tf_shape.AddDim((filter_md.get_size() / sizeof(Tfilter)));
AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor,
filter_tf_shape, filter_mkl_shape);
@@ -951,20 +1117,46 @@ class MklConvOp : public OpKernel {
MklDnnData<Tbias>* bias,
MklDnnData<Toutput>* output,
Tensor* filter_out_tensor) {
DCHECK(filter_out_tensor);
// Create reorders between user layout and MKL layout if needed, and add
// them to the net before convolution. No need to check for output
// reorder as we propagate output layout to the next layer.
src->CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(conv_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
// Rather than re-ordering to a temp buffer, reorder directly to the
// filter output tensor
filter->CheckReorderToOpMem(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS,
filter->GetTensorBuffer(filter_out_tensor));
// Create convolution primitive and add it to net.
std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> net_args;
if (bias) {
DCHECK(fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc));
net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
{MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
{MKLDNN_ARG_BIAS, bias->GetOpMem()},
{ MKLDNN_ARG_DST,
output->GetOpMem() }});
} else {
DCHECK(!fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc));
net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
{MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
{ MKLDNN_ARG_DST,
output->GetOpMem() }});
}
stream cpu_stream(cpu_engine_);
DCHECK_EQ(net.size(), net_args.size());
for (size_t i = 0; i < net.size(); ++i) {
net.at(i).execute(cpu_stream, net_args.at(i));
}
cpu_stream.wait();
#else
if (bias) {
DCHECK(fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
@@ -976,8 +1168,8 @@ class MklConvOp : public OpKernel {
filter->GetOpMem(),
output->GetOpMem()));
}
stream(stream::kind::eager).submit(net).wait();
#endif // ENABLE_MKLDNN_V1
}
// LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
@@ -990,8 +1182,55 @@ class MklConvOp : public OpKernel {
return (cached_filter_data_tensor.NumElements() == 0);
}
// Cache the converted filter in a persistent tensor.
// Only one thread can execute this method at any given time.
#ifdef ENABLE_MKLDNN_V1
void CacheFilter(OpKernelContext* context,
const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
Tfilter* filter_data, const Tensor& filter_tensor,
MklDnnData<Tfilter>& filter, const memory::desc& filter_md,
const MklDnnShape& filter_mkl_shape) LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
const Tensor& cached_filter_data_tensor =
*cached_filter_data_ptensor_.AccessTensor(context);
// If filter is already cached, there's nothing to do.
if (cached_filter_data_tensor.NumElements() > 0) {
return;
}
// Otherwise, cache filter
filter.SetUsrMem(filter_md, &filter_tensor);
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
this->cpu_engine_);
filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
Tensor* filter_tensor_ptr = nullptr;
AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr,
&filter_mkl_shape);
void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
size_t cached_filter_data_size = filter.GetOpMem().get_desc().get_size();
memcpy(cached_filter_data, filter_data, cached_filter_data_size);
}
bool AreMemoryDescriptorsEqual(const memory::desc& filter_md,
const Tensor& cached_filter_md) {
auto filter_md_data = filter_md.data;
const char* filter_data = reinterpret_cast<const char*>(&filter_md_data);
auto cached_filter_md_data = cached_filter_md.scalar<int64>()();
const char* cached_filter_data =
reinterpret_cast<const char*>(&cached_filter_md_data);
for (size_t i = 0; i < sizeof(filter_md_data); ++i) {
if (*filter_data++ != *cached_filter_data++) {
return false;
}
}
return true;
}
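Note that the loop above compares the raw bytes of the mkldnn_memory_desc_t blob because the cached side is stored in a tensor. When both sides are available as objects, v1.x's built-in comparison is the simpler equivalent (a sketch, not a drop-in for the tensor-backed cache here):

bool SameLayout(const mkldnn::memory::desc& a, const mkldnn::memory::desc& b) {
  return a == b;  // compares the full descriptor, not a format enum
}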
#else
void CacheFilter(OpKernelContext* context,
const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
Tfilter* filter_data, const Tensor& filter_tensor,
@@ -1018,22 +1257,26 @@ class MklConvOp : public OpKernel {
filter.GetOpMem().get_primitive_desc().get_size();
memcpy(cached_filter_data, filter_data, cached_filter_data_size);
}
#endif // ENABLE_MKLDNN_V1
Tfilter* GetCachedFilter(OpKernelContext* context,
const MEMORY_DESC& filter_md) LOCKS_EXCLUDED(mu_) {
tf_shared_lock lock(mu_);
const Tensor& cached_filter_data =
*cached_filter_data_ptensor_.AccessTensor(context);
const Tensor& cached_filter_md =
*cached_filter_md_ptensor_.AccessTensor(context);
// Check if the memory descriptor of the cached weights is the same as
// filter_md. If so, we can use the cached weights; otherwise return
// nullptr.
#ifdef ENABLE_MKLDNN_V1
if (cached_filter_md.scalar<int64>().size() &&
AreMemoryDescriptorsEqual(filter_md, cached_filter_md)) {
#else
if (cached_filter_md.scalar<int32>().size() &&
cached_filter_md.scalar<int32>()() == filter_md) {
#endif // ENABLE_MKLDNN_V1
return static_cast<Tfilter*>(
const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
}
@@ -1069,26 +1312,26 @@ class MklFusedConvOp
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"Relu"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
} else if (fused_ops == std::vector<string>{"Relu6"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
} else if (fused_ops == std::vector<string>{"Elu"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_elu);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
@@ -1102,7 +1345,7 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
OP_REQUIRES(
context, num_args == 2,
errors::InvalidArgument(
@@ -1110,7 +1353,7 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
this->set_fuse_biasadd(true);
this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
OP_REQUIRES(
context, num_args == 2,
errors::InvalidArgument(
@@ -1118,7 +1361,7 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu);
OP_REQUIRES(
context, num_args == 2,
errors::InvalidArgument(
@@ -1274,7 +1517,7 @@ class MklQuantizedConv2DOp
(255.0f * 127.0f * output_range);
}
params.post_op_params.push_back(
{"output_scale", ALGORITHM_UNDEF, scales});
}
}
@@ -1293,7 +1536,6 @@ class MklQuantizedConv2DOp
const float* min_filter = min_filter_vector.flat<float>().data();
const float* max_filter = max_filter_vector.flat<float>().data();
if (bias_enabled) {
if (std::is_same<Tbias, qint32>::value) {
return static_cast<Tbias*>(
@@ -1315,21 +1557,21 @@ class MklQuantizedConv2DOp
} else {
bias_attr.set_output_scales(1, scales);
}
auto bias_md =
MEMORY_PD_CONSTRUCTOR(static_cast<int>(bias_tensor.NumElements()),
Tbias, x, this->cpu_engine_);
void* bias_buf = static_cast<void*>(
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
input_bias_ =
new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf);
scaled_bias_ = new MEMORY_CONSTRUCTOR_WITHOUT_DATA(
conv_fwd_pd->PRIMITIVE_DESC_BIAS, this->cpu_engine_);
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
bias_attr);
CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
this->cpu_engine_);
return reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
} else {
return nullptr;
@@ -1358,7 +1600,7 @@ class MklQuantizedConv2DReluOp
MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output, bias_enabled,
is_depthwise>::ExtendConvFwdParams(context, params);
params.post_op_params.push_back(
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}});
}
};
@@ -1415,23 +1657,23 @@ class MklQuantizedConv2DSumReluOp
// If it is not then it is DT_INT8 and is scaled appropriately.
if (summand_type == DT_QUINT8)
params.post_op_params.push_back(
{"sum", ALGORITHM_UNDEF, {scale_summand / scale_output}});
else
params.post_op_params.push_back(
{"sum",
ALGORITHM_UNDEF,
{255.0f * scale_summand / (scale_output * 127.0f)}});
} else {
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}});
}
params.post_op_params.push_back(
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}});
}
void AllocateOutputTensor(OpKernelContext* context,
const ConvFwdPd& conv_prim_desc,
const memory::dims& output_dims_mkl_order,
MKL_TENSOR_FORMAT output_tf_format,
Tensor** output_tensor) override {
int summand_idx = context->num_inputs() / 2 - 1;
if (std::is_same<Toutput, quint8>::value) {
@@ -1503,20 +1745,22 @@ class MklQuantizedConv2DSumReluOp
summand_mkl_shape.IsMklTensor()
? summand_mkl_shape.GetMklLayout()
: memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
MEMORY_FORMAT::nhwc);
#ifndef ENABLE_MKLDNN_V1
auto summand_pd = memory::primitive_desc(summand_md, this->cpu_engine_);
#endif // !ENABLE_MKLDNN_V1
void* summand_buf =
static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
void* dst_buf =
static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
summand_ =
new MEMORY_CONSTRUCTOR(SUMMAND_MD, this->cpu_engine_, summand_buf);
dst_ = new MEMORY_CONSTRUCTOR(conv_prim_desc.PRIMITIVE_DESC_DST,
this->cpu_engine_, dst_buf);
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
reorder_attr);
CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_);
}
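The summand path above rescales while reordering by attaching output scales to the reorder's primitive attributes. A minimal standalone sketch of that pattern under the v1.x API (engine, memories, and scale assumed):

mkldnn::primitive_attr reorder_attr;
reorder_attr.set_output_scales(0, {scale});  // mask 0: one scale for all values
mkldnn::reorder::primitive_desc reorder_pd(
    eng, summand_mem.get_desc(), eng, dst_mem.get_desc(), reorder_attr);
CreateAndExecuteReorder(reorder_pd, summand_mem, dst_mem, eng);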
memory* summand_ = nullptr;
@@ -1970,5 +2214,36 @@ TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED);
TF_CALL_float(REGISTER_MKL_CPU_3D);
TF_CALL_bfloat16(REGISTER_MKL_CPU_3D);
#undef ADD_MD
#undef ALGORITHM
#undef ALGORITHM_UNDEF
#undef CPU_STREAM
#undef DATA_WITH_ENGINE
#undef DST_MD
#undef ENGINE_CPU
#undef GET_DESC
#undef GET_MEMORY_DESC_CONSTRUCTOR
#undef GET_SRC_DESC_FROM_OP_PD
#undef GET_WEIGHTS_DESC_FROM_OP_PD
#undef GET_WEIGHTS_FORMAT_FROM_OP_PD
#undef IS_FILTER_REORDER_NEEDED
#undef IS_SRC_REORDER_NEEDED
#undef MEMORY_CONSTRUCTOR
#undef MEMORY_CONSTRUCTOR_USING_MEM_PD
#undef MEMORY_CONSTRUCTOR_WITHOUT_DATA
#undef MEMORY_DESC
#undef MEMORY_FORMAT
#undef MEMORY_PD_CONSTRUCTOR
#undef MEMORY_PD_WITHOUT_DATA
#undef MKL_TENSOR_FORMAT
#undef MKL_TENSOR_FORMAT_BLOCKED
#undef MKL_TENSOR_FORMAT_IN_C
#undef PRIMITIVE_DESC_BIAS
#undef PRIMITIVE_DESC_DST
#undef PRIMITIVE_DESC_SRC
#undef PRIMITIVE_DESC_WEIGHTS
#undef REORDER_PD_CONSTRUCTOR
#undef REORDER_PD_CONSTRUCTOR_WITH_ATTR
#undef SUMMAND_MD
} // namespace tensorflow
#endif // INTEL_MKL

View File

@@ -40,13 +40,21 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#ifndef ENABLE_MKLDNN_V1
using mkldnn::convolution_direct;
#endif // !ENABLE_MKLDNN_V1
using mkldnn::convolution_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
namespace tensorflow {
#ifdef ENABLE_MKLDNN_V1
#define MKLDNN_SIZE_DTYPE long int
#else
#define MKLDNN_SIZE_DTYPE int
#endif // ENABLE_MKLDNN_V1
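The typedef is needed because v1.x widened dimension storage: memory::dims holds 64-bit mkldnn_dim_t values, whereas v0.x used int. Sizing the scratch vectors with MKLDNN_SIZE_DTYPE keeps the element types aligned, e.g.:

std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
mkldnn_sizes[MklDnnDims::Dim_N] = 8;  // sample batch size
// memory::dims has the same element type under the matching build, so this
// conversion involves no narrowing:
mkldnn::memory::dims dims(mkldnn_sizes.begin(), mkldnn_sizes.end());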
class MklDnnConvUtil {
protected:
OpKernelContext* context_; // We don't own this.
@@ -137,7 +145,7 @@ class MklDnnConvUtil {
int input_cols = static_cast<int>(input_cols_raw);
// MKL-DNN always requires input in NCHW format for Conv2D.
std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
@@ -161,7 +169,7 @@ class MklDnnConvUtil {
int input_cols = static_cast<int>(input_cols_raw);
// MKL-DNN always requires input in NCDHW format for Conv3D.
std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
@@ -225,7 +233,7 @@ class MklDnnConvUtil {
// GOIHW = (group, out_depth, in_depth, rows, cols)
// Specifically for depthwise G=filter_indepth, O=filter_outdepth, I=1
if (is_depthwise) {
std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
mkldnn_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth;
mkldnn_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth;
mkldnn_sizes[MKL_GROUP_FILTER_DIM_I] = 1;
@@ -234,7 +242,7 @@ class MklDnnConvUtil {
*filter_dims = mkldnn_sizes;
} else {
std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
mkldnn_sizes[MklDnnDims::Dim_O] = filter_out_depth;
mkldnn_sizes[MklDnnDims::Dim_I] = filter_in_depth;
mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
@@ -262,7 +270,7 @@ class MklDnnConvUtil {
// MKL-DNN always needs filter in OIDHW format.
// OIDHW = (out_depth, in_depth, planes, rows, cols)
std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
mkldnn_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth;
mkldnn_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth;
mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
@@ -455,14 +463,14 @@ class MklDnnConvUtil {
if (is_conv2d) {
// For Conv2D, MKL-DNN always needs output in NCHW format.
std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
*output_dims_mkl_order = mkldnn_sizes;
} else {
std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
@@ -624,6 +632,8 @@ class MklDummyOp : public OpKernel {
}
};
#undef MKLDNN_SIZE_DTYPE
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_

View File

@@ -1194,6 +1194,27 @@ inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
return memory::desc(md);
}
inline void CreateAndExecuteReorder(const reorder::primitive_desc& reorder_desc,
const memory& src_mem,
const memory& dst_mem,
const engine& engine) {
std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1
net.push_back(mkldnn::reorder(reorder_desc));
std::vector<MemoryArgsMap> net_args;
net_args.push_back({{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
DCHECK_EQ(net.size(), net_args.size());
stream cpu_stream(engine);
for (size_t i = 0; i < net.size(); ++i) {
net.at(i).execute(cpu_stream, net_args.at(i));
}
cpu_stream.wait();
#else
net.push_back(mkldnn::reorder(reorder_desc, src_mem, dst_mem));
stream(stream::kind::eager).submit(net).wait();
#endif // ENABLE_MKLDNN_V1
}
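A hypothetical call site for the helper above, on the v1.x path (descriptors and buffers assumed):

mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
mkldnn::memory src_mem(user_md, eng, user_buf);
mkldnn::memory dst_mem(op_md, eng, op_buf);
mkldnn::reorder::primitive_desc reorder_pd(eng, user_md, eng, op_md);
CreateAndExecuteReorder(reorder_pd, src_mem, dst_mem, eng);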
template <typename T>
inline primitive FindOrCreateReorder(const memory* from, const memory* to);