Merge pull request #30549 from Intel-tensorflow:mkldnn-1.0-conv2d-fwd
PiperOrigin-RevId: 260989299
commit 644e458325
@@ -353,9 +353,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.mul = "Mul";
     csinfo_.squared_difference = "SquaredDifference";
     csinfo_.sub = "Sub";
     // End - element-wise ops. See note above.

     // NOTE: names are alphabetically sorted.
+#ifndef ENABLE_MKLDNN_V1
     rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn),
                       CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
@@ -389,10 +390,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                      {csinfo_.conjugate_transpose,
                       mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose),
                       CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange});
+#endif  // !ENABLE_MKLDNN_V1
     rinfo_.push_back({csinfo_.conv2d,
                       mkl_op_registry::GetMklOpName(csinfo_.conv2d),
                       CopyAttrsConvCheckConstFilter, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
+#ifndef ENABLE_MKLDNN_V1
     rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
                       CopyAttrsConvCheckConstFilter, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
@@ -641,6 +644,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back(
         {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
+#endif  // !ENABLE_MKLDNN_V1
     // Disable these two MKL operators for now due to some test failures caused
     // by these two ops
     /*
@@ -653,6 +657,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                       CopyAttrsAll, AlwaysRewrite,
                       kRewriteForLayoutPropagation});
     */
+#ifndef ENABLE_MKLDNN_V1
     rinfo_.push_back(
         {csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
@@ -753,6 +758,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
                       // CheckForMklOp
                       FuseConv3D,
                       CopyAttrsConv});
+#endif  // !ENABLE_MKLDNN_V1
   }

   // Standard interface to run pass

@@ -47,13 +47,96 @@ limitations under the License.
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"

-using mkldnn::convolution_direct;
 using mkldnn::convolution_forward;
 using mkldnn::prop_kind;
 using mkldnn::stream;

 namespace tensorflow {

+#ifdef ENABLE_MKLDNN_V1
+#define ADD_MD add_md
+#define ALGORITHM mkldnn::algorithm
+#define ALGORITHM_UNDEF ALGORITHM::undef
+#define CPU_STREAM(engine) stream(engine)
+#define DATA_WITH_ENGINE(data, engine) data, engine
+#define DST_MD dst_md
+#define ENGINE_CPU engine::kind::cpu
+#define GET_DESC get_desc()
+#define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \
+  { {dims}, MklDnnType<type>(), memory::format_tag::fm }
+#define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd->src_desc()
+#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd->weights_desc()
+#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op_primitive) \
+  GET_WEIGHTS_DESC_FROM_OP_PD(op_pd)
+#define IS_FILTER_REORDER_NEEDED(filter_md, op_pd, op_primitive) \
+  filter_md != op_pd->weights_desc()
+#define IS_SRC_REORDER_NEEDED(src_md, op_pd, op_primitive) \
+  src_md != op_pd->src_desc()
+#define MEMORY_CONSTRUCTOR(mem_desc, engine, data) \
+  memory(mem_desc, engine, data)
+#define MEMORY_CONSTRUCTOR_USING_MEM_PD(dims, type, fm, engine, data) \
+  memory(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine, data)
+#define MEMORY_CONSTRUCTOR_WITHOUT_DATA(mem_desc, engine) \
+  memory(mem_desc, engine)
+#define MEMORY_DESC memory::desc
+#define MEMORY_FORMAT mkldnn::memory::format_tag
+#define MEMORY_PD_CONSTRUCTOR(dims, type, fm, engine) \
+  memory::desc({dims}, MklDnnType<type>(), memory::format_tag::fm)
+#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
+#define MKL_TENSOR_FORMAT MklTensorFormat
+#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
+#define MKL_TENSOR_FORMAT_IN_C MKL_TENSOR_FORMAT
+#define PRIMITIVE_DESC_BIAS bias_desc()
+#define PRIMITIVE_DESC_DST dst_desc()
+#define PRIMITIVE_DESC_SRC src_desc()
+#define PRIMITIVE_DESC_WEIGHTS weights_desc()
+#define REORDER_PD_CONSTRUCTOR(src_md, dst_md, engine) \
+  mkldnn::reorder::primitive_desc(engine, src_md, engine, dst_md)
+#define REORDER_PD_CONSTRUCTOR_WITH_ATTR(src_md, dst_md, engine, prim_attr) \
+  mkldnn::reorder::primitive_desc(engine, src_md, engine, dst_md, prim_attr)
+#define SUMMAND_MD summand_md
+#else
+#define ADD_MD add_pd
+#define ALGORITHM mkldnn
+#define ALGORITHM_UNDEF ALGORITHM::algorithm_undef
+#define CPU_STREAM(engine) stream(stream::kind::eager)
+#define DATA_WITH_ENGINE(data, engine) data
+#define DST_MD dst_pd
+#define ENGINE_CPU engine::cpu
+#define GET_DESC get_primitive_desc()
+#define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \
+  { {dims}, MklDnnType<type>(), memory::format::fm }
+#define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd.get()->src_primitive_desc()
+#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd.get()->weights_primitive_desc()
+#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op_primitive) \
+  op_primitive->GetFilterMemoryFormat()
+#define IS_FILTER_REORDER_NEEDED(filter_md, op_pd, op_primitive) \
+  filter_md.data.format != op_primitive->GetFilterMemoryFormat()
+#define IS_SRC_REORDER_NEEDED(src_md, op_pd, op_primitive) \
+  src_md.data.format != op_primitive->GetSrcMemoryFormat()
+#define MEMORY_CONSTRUCTOR(mem_pd, engine, data) memory(mem_pd, data)
+#define MEMORY_CONSTRUCTOR_USING_MEM_PD(dims, type, fm, engine, data) \
+  memory({GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine}, data)
+#define MEMORY_CONSTRUCTOR_WITHOUT_DATA(mem_pd, engine) memory(mem_pd)
+#define MEMORY_DESC memory::format
+#define MEMORY_FORMAT mkldnn::memory::format
+#define MEMORY_PD_CONSTRUCTOR(dims, type, fm, engine) \
+  memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
+#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
+#define MKL_TENSOR_FORMAT memory::format
+#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked
+#define MKL_TENSOR_FORMAT_IN_C mkldnn_memory_format_t
+#define PRIMITIVE_DESC_BIAS bias_primitive_desc()
+#define PRIMITIVE_DESC_DST dst_primitive_desc()
+#define PRIMITIVE_DESC_SRC src_primitive_desc()
+#define PRIMITIVE_DESC_WEIGHTS weights_primitive_desc()
+#define REORDER_PD_CONSTRUCTOR(src_pd, dst_pd, engine) \
+  mkldnn::reorder::primitive_desc(src_pd, dst_pd)
+#define REORDER_PD_CONSTRUCTOR_WITH_ATTR(src_pd, dst_pd, engine, prim_attr) \
+  mkldnn::reorder::primitive_desc(src_pd, dst_pd, prim_attr)
+#define SUMMAND_MD summand_pd
+#endif  // ENABLE_MKLDNN_V1
+
 // This structure aggregates multiple inputs to Conv2DFwd* methods.
 struct MklConvFwdParams {
   memory::dims src_dims;
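[Editor's note] The macro table above is the core of the change: every call site is written once, and the preprocessor picks the MKL-DNN v0.x or v1.x spelling. A minimal standalone sketch of the same idiom, with illustrative names only (BUILD_V1 stands in for ENABLE_MKLDNN_V1; api_v0/api_v1 stand in for the two library APIs):

    // compat_macro_sketch.cc -- hypothetical, not part of the patch
    #include <iostream>

    namespace api_v0 { inline int make_memory() { return 0; } }
    namespace api_v1 { inline int make_memory(int engine) { return engine; } }

    #ifdef BUILD_V1
    #define MAKE_MEMORY(engine) api_v1::make_memory(engine)
    #else
    // v0.x has no engine parameter, so the macro drops it.
    #define MAKE_MEMORY(engine) api_v0::make_memory()
    #endif

    int main() {
      int engine = 1;
      // One call site compiles against either API.
      std::cout << MAKE_MEMORY(engine) << "\n";
      return 0;
    }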
@@ -94,9 +177,9 @@ template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
 class MklConvFwdPrimitive : public MklPrimitive {
  public:
   explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
-      : cpu_engine_(engine::cpu, 0) {
-    context_.fwd_stream.reset(new stream(stream::kind::eager));
-    // Create conv primitive
+      : cpu_engine_(ENGINE_CPU, 0) {
+    context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
+    // Create convolution primitive
     if (context_.conv_fwd == nullptr) {
       Setup(convFwdDims);
     }
@@ -115,19 +198,30 @@ class MklConvFwdPrimitive : public MklPrimitive {
         static_cast<void*>(const_cast<Tinput*>(src_data)));
     context_.filter_mem->set_data_handle(
         static_cast<void*>(const_cast<Tfilter*>(filter_data)));
-    context_.bias_mem->set_data_handle(
-        static_cast<void*>(const_cast<Tbias*>(bias_data)));
+    if (bias_data != nullptr) {
+      context_.bias_mem->set_data_handle(
+          static_cast<void*>(const_cast<Tbias*>(bias_data)));
+    }
     context_.dst_mem->set_data_handle(
         static_cast<void*>(const_cast<Toutput*>(dst_data)));
+#ifdef ENABLE_MKLDNN_V1
+    DCHECK_EQ(context_.fwd_primitives.size(),
+              context_.fwd_primitives_args.size());
+    for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
+      context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
+                                            context_.fwd_primitives_args.at(i));
+    }
+#else
     context_.fwd_stream->submit(context_.fwd_primitives);
+#endif  // ENABLE_MKLDNN_V1

-    // After exec, set data handle back
+    // After execution, set data handle back
     context_.src_mem->set_data_handle(DummyData);
     context_.filter_mem->set_data_handle(DummyData);
-    context_.bias_mem->set_data_handle(DummyData);
+    if (bias_data != nullptr) {
+      context_.bias_mem->set_data_handle(DummyData);
+    }
     context_.dst_mem->set_data_handle(DummyData);

     return;
   }

   // Convolution forward execute without bias
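[Editor's note] The #ifdef block in the hunk above is the key behavioral difference between the two APIs: v0.x primitives are bound to their memories at construction and a whole net is submitted to an eager stream, while v1.x primitives take an argument map at execution time. A hedged, self-contained sketch of the v1.x model using a reorder primitive (assumes MKL-DNN 1.x headers and library):

    // v1_execute_sketch.cc -- illustrative only
    #include <unordered_map>
    #include "mkldnn.hpp"

    int main() {
      using namespace mkldnn;
      engine eng(engine::kind::cpu, 0);
      stream s(eng);

      memory::desc md({1, 8, 4, 4}, memory::data_type::f32,
                      memory::format_tag::nchw);
      memory src(md, eng), dst(md, eng);  // library-allocated buffers

      // Build the primitive once; bind memories per execute() call.
      reorder r(reorder::primitive_desc(eng, md, eng, md));
      std::unordered_map<int, memory> args = {{MKLDNN_ARG_FROM, src},
                                              {MKLDNN_ARG_TO, dst}};
      r.execute(s, args);
      s.wait();  // v0.x instead did stream(stream::kind::eager).submit(net).wait()
      return 0;
    }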
@@ -136,23 +230,15 @@ class MklConvFwdPrimitive : public MklPrimitive {
   //   dst_data: output data buffer of dst
   void Execute(const Tinput* src_data, const Tfilter* filter_data,
                const Toutput* dst_data) {
-    context_.src_mem->set_data_handle(
-        static_cast<void*>(const_cast<Tinput*>(src_data)));
-    context_.filter_mem->set_data_handle(
-        static_cast<void*>(const_cast<Tfilter*>(filter_data)));
-    context_.dst_mem->set_data_handle(
-        static_cast<void*>(const_cast<Toutput*>(dst_data)));
-    context_.fwd_stream->submit(context_.fwd_primitives);
-
-    // After execution, set data handle back
-    context_.src_mem->set_data_handle(DummyData);
-    context_.filter_mem->set_data_handle(DummyData);
-    context_.dst_mem->set_data_handle(DummyData);
+    Execute(src_data, filter_data, nullptr, dst_data);
   }

+#ifndef ENABLE_MKLDNN_V1
+  // In MKL-DNN v1.x, memory format tags only provide a partial description
+  // of the memory layout. Hence, these functions are disabled for v1.x.
   memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }

   memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
+#endif  // !ENABLE_MKLDNN_V1

   std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
     return context_.fwd_pd;
@@ -161,17 +247,19 @@ class MklConvFwdPrimitive : public MklPrimitive {
  private:
   // Primitive reuse context for Conv2D Fwd op
   struct ConvFwdContext {
+#ifndef ENABLE_MKLDNN_V1
+    // Expected memory format for this primitive instance
     memory::format src_fmt;
     memory::format filter_fmt;
+#endif  // !ENABLE_MKLDNN_V1

-    // MKLDNN memory
+    // MKL-DNN memory
     std::shared_ptr<mkldnn::memory> src_mem;
     std::shared_ptr<mkldnn::memory> filter_mem;
     std::shared_ptr<mkldnn::memory> bias_mem;
     std::shared_ptr<mkldnn::memory> dst_mem;

-    // Desc & prmitive desc
+    // Desc & primitive desc
     std::shared_ptr<mkldnn::convolution_forward::desc> fwd_desc;

     // Memory desc
@@ -187,9 +275,16 @@ class MklConvFwdPrimitive : public MklPrimitive {
     std::shared_ptr<mkldnn::stream> fwd_stream;
     std::vector<mkldnn::primitive> fwd_primitives;

+#ifdef ENABLE_MKLDNN_V1
+    std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
+#endif  // ENABLE_MKLDNN_V1
+
     ConvFwdContext()
-        : src_fmt(memory::format::any),
+        :
+#ifndef ENABLE_MKLDNN_V1
+          src_fmt(memory::format::any),
           filter_fmt(memory::format::any),
+#endif  // !ENABLE_MKLDNN_V1
           src_mem(nullptr),
           filter_mem(nullptr),
           bias_mem(nullptr),
@@ -200,34 +295,35 @@ class MklConvFwdPrimitive : public MklPrimitive {
           bias_md(nullptr),
           fwd_pd(nullptr),
           conv_fwd(nullptr),
-          fwd_stream(nullptr) {}
+          fwd_stream(nullptr) {
+    }
   };

   void Setup(const MklConvFwdParams& convFwdDims) {
     // Create memory descriptors for convolution data w/ no specified format
     context_.src_md.reset(new memory::desc(
-        {convFwdDims.src_dims}, MklDnnType<Tinput>(), memory::format::any));
+        {convFwdDims.src_dims}, MklDnnType<Tinput>(), MEMORY_FORMAT::any));

     context_.filter_md.reset(new memory::desc(
-        {convFwdDims.filter_dims}, MklDnnType<Tfilter>(), memory::format::any));
+        {convFwdDims.filter_dims}, MklDnnType<Tfilter>(), MEMORY_FORMAT::any));

     context_.dst_md.reset(new memory::desc(
-        {convFwdDims.dst_dims}, MklDnnType<Toutput>(), memory::format::any));
+        {convFwdDims.dst_dims}, MklDnnType<Toutput>(), MEMORY_FORMAT::any));

     if (!convFwdDims.bias_dims.empty())
       context_.bias_md.reset(new memory::desc(
-          {convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format::any));
+          {convFwdDims.bias_dims}, MklDnnType<Tbias>(), MEMORY_FORMAT::any));

-    // Create a convolution
+    // Create a convolution descriptor
     if (!convFwdDims.bias_dims.empty()) {
       context_.fwd_desc.reset(new convolution_forward::desc(
-          prop_kind::forward, convolution_direct, *context_.src_md,
+          prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
           *context_.filter_md, *context_.bias_md, *context_.dst_md,
           convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
           convFwdDims.padding_right, padding_kind::zero));
     } else {
       context_.fwd_desc.reset(new convolution_forward::desc(
-          prop_kind::forward, convolution_direct, *context_.src_md,
+          prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
           *context_.filter_md, *context_.dst_md, convFwdDims.strides,
           convFwdDims.dilations, convFwdDims.padding_left,
           convFwdDims.padding_right, padding_kind::zero));
@@ -246,7 +342,12 @@ class MklConvFwdPrimitive : public MklPrimitive {
         float op_scale = post_op_param.param[0];
         float op_alpha = post_op_param.param[1];
         float op_beta = post_op_param.param[2];
+#ifdef ENABLE_MKLDNN_V1
+        post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu,
+                                op_alpha,
+#else
         post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha,
+#endif  // ENABLE_MKLDNN_V1
                                 op_beta);
       } else if (post_op_param.name == "sum") {
         DCHECK_EQ(post_op_param.param.size(), 1);
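[Editor's note] For reference, the post-op attachment that the loop above builds looks like this in isolation; a hedged sketch assuming MKL-DNN 1.x (parameter order is scale, algorithm, alpha, beta, matching the append_eltwise call in the hunk):

    // post_ops_sketch.cc -- illustrative only
    #include "mkldnn.hpp"

    int main() {
      mkldnn::post_ops po;
      // Fuse a ReLU after the convolution: scale 1.0, alpha/beta unused (0).
      po.append_eltwise(1.0f, mkldnn::algorithm::eltwise_relu, 0.0f, 0.0f);
      mkldnn::primitive_attr attr;
      attr.set_post_ops(po);  // attr is then passed to the conv primitive_desc
      return 0;
    }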
@@ -271,27 +372,44 @@ class MklConvFwdPrimitive : public MklPrimitive {
       context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
     }

+#ifndef ENABLE_MKLDNN_V1
     // Store the expected memory format
     context_.src_fmt = static_cast<mkldnn::memory::format>(
         context_.fwd_pd.get()->src_primitive_desc().desc().data.format);

     context_.filter_fmt = static_cast<mkldnn::memory::format>(
         context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
+#endif  // !ENABLE_MKLDNN_V1

     // Create memory primitive based on dummy data
-    context_.src_mem.reset(
-        new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData));
-    context_.filter_mem.reset(
-        new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData));
-    context_.dst_mem.reset(
-        new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
+    context_.src_mem.reset(new MEMORY_CONSTRUCTOR(
+        context_.fwd_pd.get()->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData));
+    context_.filter_mem.reset(new MEMORY_CONSTRUCTOR(
+        context_.fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_, DummyData));
+    context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(
+        context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));

     // Create convolution primitive and add it to net
     if (!convFwdDims.bias_dims.empty()) {
-      context_.bias_mem.reset(new memory(
-          {{{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format::x},
-           cpu_engine_},
-          DummyData));
+      context_.bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
+          convFwdDims.bias_dims, Tbias, x, cpu_engine_, DummyData));
+#ifdef ENABLE_MKLDNN_V1
+      context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
+      context_.fwd_primitives_args.push_back(
+          {{MKLDNN_ARG_SRC, *context_.src_mem},
+           {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
+           {MKLDNN_ARG_BIAS, *context_.bias_mem},
+           {MKLDNN_ARG_DST, *context_.dst_mem}});
+    } else {
+      context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
+      context_.fwd_primitives_args.push_back(
+          {{MKLDNN_ARG_SRC, *context_.src_mem},
+           {MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
+           {MKLDNN_ARG_DST, *context_.dst_mem}});
+    }
+#else
       context_.conv_fwd.reset(new convolution_forward(
           *context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
           *context_.bias_mem, *context_.dst_mem));
@@ -300,9 +418,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
           new convolution_forward(*context_.fwd_pd, *context_.src_mem,
                                   *context_.filter_mem, *context_.dst_mem));
     }
-
+#endif  // ENABLE_MKLDNN_V1
     context_.fwd_primitives.push_back(*context_.conv_fwd);
     return;
   }

   struct ConvFwdContext context_;
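[Editor's note] Setup() allocates every memory object once against the DummyData sentinel; Execute() then swaps the real buffers in and out on each call. A hedged sketch of that reuse pattern (MKL-DNN 1.x; the null handle stands in for TensorFlow's DummyData):

    // primitive_reuse_sketch.cc -- illustrative only
    #include <vector>
    #include "mkldnn.hpp"

    int main() {
      using namespace mkldnn;
      engine eng(engine::kind::cpu, 0);
      memory::desc md({4}, memory::data_type::f32, memory::format_tag::x);
      std::vector<float> buf(4, 1.0f);

      memory mem(md, eng, nullptr);     // created once, no real buffer yet
      mem.set_data_handle(buf.data());  // repoint at user data before execute
      // ... execute primitives that reference `mem` ...
      mem.set_data_handle(nullptr);     // detach afterwards, as Execute() does
      return 0;
    }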
@@ -566,6 +683,12 @@ class MklConvOp : public OpKernel {
     auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
                             : TFDataFormatToMklDnn3DDataFormat(data_format_);

+#ifdef ENABLE_MKLDNN_V1
+    auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
+    // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
+    DCHECK_NE(mkl_fmt_tag, memory::format_tag::undef);
+#endif  // ENABLE_MKLDNN_V1
+
     // If input is in MKL layout, then simply grab the layout; otherwise,
     // construct TF layout for input.
     // For constructing TF layout for input, although input shape (src_dims)
@@ -573,18 +696,22 @@ class MklConvOp : public OpKernel {
     // TF layout depending on the data format:
     //     Conv2D: NHWC or NCHW
     //     Conv3D: NDHWC or NCDHW
-    auto src_md = src_mkl_shape.IsMklTensor()
-                      ? src_mkl_shape.GetMklLayout()
-                      : memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
+    auto src_md =
+        src_mkl_shape.IsMklTensor()
+            ? src_mkl_shape.GetMklLayout()
+#ifdef ENABLE_MKLDNN_V1
+            : memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
+#else
+            : memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
+#endif  // ENABLE_MKLDNN_V1
     src.SetUsrMem(src_md, &src_tensor);

     // Although filter shape (filter_dims) required is in MKL-DNN order,
     // the layout is Tensorflow's layout (HWIO) and (HWIGO) for
     // depthwise/group convolutions.
-    auto filter_format = is_conv2d ? (is_depthwise ? memory::format::hwigo
-                                                   : memory::format::hwio)
-                                   : memory::format::dhwio;
+    auto filter_format = is_conv2d ? (is_depthwise ? MEMORY_FORMAT::hwigo
+                                                   : MEMORY_FORMAT::hwio)
+                                   : MEMORY_FORMAT::dhwio;

     DCHECK(!filter_mkl_shape.IsMklTensor());
     auto filter_md =
@@ -593,7 +720,7 @@ class MklConvOp : public OpKernel {
             : memory::desc(filter_dims, MklDnnType<Tfilter>(), filter_format);
     filter.SetUsrMem(filter_md, &filter_tensor);

-    // MKLDNN dilations start from 0.
+    // MKL-DNN dilations start from 0.
     for (int i = 0; i < dilations.size(); ++i) --dilations[i];

     // In some cases, primitive descriptor could potentially contain
@@ -643,10 +770,11 @@ class MklConvOp : public OpKernel {

       // Check whether src and filter need to be reordered
       Tinput* src_data = nullptr;
-      if (src_md.data.format != conv_fwd->GetSrcMemoryFormat()) {
+      if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
         // Reorder src
         src.SetUsrMem(src_md, &src_tensor);
-        src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc());
+        src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
+            GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
         src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
       } else {
         src_data = static_cast<Tinput*>(
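[Editor's note] IS_SRC_REORDER_NEEDED hides the second big API difference: v0.x compares integer format enums, while in v1.x a memory::desc fully describes the layout, so whole descriptors are compared. A hedged sketch of the v1.x check (assumes MKL-DNN 1.x):

    // desc_compare_sketch.cc -- illustrative only
    #include <cassert>
    #include "mkldnn.hpp"

    int main() {
      using namespace mkldnn;
      memory::desc a({1, 8, 4, 4}, memory::data_type::f32,
                     memory::format_tag::nchw);
      memory::desc b({1, 8, 4, 4}, memory::data_type::f32,
                     memory::format_tag::nhwc);
      assert(a != b);  // same dims, different layout -> a reorder is needed
      return 0;
    }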
@@ -654,7 +782,7 @@ class MklConvOp : public OpKernel {
       }

       Tfilter* filter_data = nullptr;
-      if (filter_md.data.format != conv_fwd->GetFilterMemoryFormat()) {
+      if (IS_FILTER_REORDER_NEEDED(filter_md, conv_fwd_pd, conv_fwd)) {
         bool is_filter_cached = false;
         // If filter is a constant, we can avoid the conversion of filter from
         // Tensorflow format to MKL format by caching the filter when it is
@@ -664,21 +792,26 @@ class MklConvOp : public OpKernel {
           if (IsFilterCacheEmpty(context)) {
             // Cache filter if it is not already cached.
             CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
+#ifdef ENABLE_MKLDNN_V1
+                        filter, filter_md, filter_mkl_shape);
+#else
                         filter, filter_md);
+#endif  // ENABLE_MKLDNN_V1
           }
-          filter_data =
-              GetCachedFilter(context, conv_fwd->GetFilterMemoryFormat());
+          filter_data = GetCachedFilter(
+              context, GET_WEIGHTS_FORMAT_FROM_OP_PD(conv_fwd_pd, conv_fwd));
           is_filter_cached = (filter_data != nullptr);
         }
         if (!is_filter_cached) {
           filter.SetUsrMem(filter_md, &filter_tensor);
           if (filter_out_tensor == nullptr) {
-            filter.CheckReorderToOpMem(
-                conv_fwd_pd.get()->weights_primitive_desc());
+            filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
+                GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
           } else {
             filter.CheckReorderToOpMem(
-                conv_fwd_pd.get()->weights_primitive_desc(),
-                filter.GetTensorBuffer(filter_out_tensor));
+                GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
+                DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
+                                 cpu_engine_));
           }
           filter_data =
               static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
@@ -787,7 +920,7 @@ class MklConvOp : public OpKernel {
     // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
     // checking `fuse_biasadd_` flag.
     if (fuse_add_) {
-      params.post_op_params.push_back({"sum", mkldnn::algorithm_undef, {1.0}});
+      params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}});
     }
     if (fuse_activation_) {
       params.post_op_params.push_back(
@@ -808,29 +941,35 @@ class MklConvOp : public OpKernel {
   virtual void AllocateOutputTensor(OpKernelContext* context,
                                     const ConvFwdPd& conv_prim_desc,
                                     const memory::dims& output_dims_mkl_order,
-                                    memory::format output_tf_format,
+                                    MKL_TENSOR_FORMAT output_tf_format,
                                     Tensor** output_tensor) {
-    CHECK_NOTNULL(output_tensor);
+    DCHECK(output_tensor);
+#ifdef ENABLE_MKLDNN_V1
+    auto dst_md = conv_prim_desc.dst_desc();
+#else
     auto dst_pd = conv_prim_desc.dst_primitive_desc();
     auto dst_md = dst_pd.desc();
+#endif  // ENABLE_MKLDNN_V1

     if (!std::is_same<Ttemp_output, Toutput>::value) {
       dst_md.data.data_type =
           static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
+#ifndef ENABLE_MKLDNN_V1
       dst_pd = memory::primitive_desc(dst_md, cpu_engine_);
+#endif  // !ENABLE_MKLDNN_V1
     }
-    // Allocate shape of Mkl tensor.
+
+    // Allocate shape of MKL tensor
     MklDnnShape output_mkl_shape;
     output_mkl_shape.SetMklTensor(true);
-    output_mkl_shape.SetMklLayout(&dst_pd);
+    output_mkl_shape.SetMklLayout(&DST_MD);
     output_mkl_shape.SetElemType(MklDnnType<Toutput>());
     output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                  output_dims_mkl_order, output_tf_format);

-    // Allocate shape of TF tensor.
+    // Allocate shape of TF tensor
     TensorShape output_tf_shape;
-    output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput)));
+    output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput)));

     AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
                               output_tf_shape, output_mkl_shape);
     if (fuse_add_) {
@@ -838,32 +977,40 @@ class MklConvOp : public OpKernel {
       MklDnnShape add_mkl_shape;
       GetMklShape(context, kInputIndex_Add, &add_mkl_shape);

-      // Check if need reorder
+      // Check if reorder is needed
       if (add_mkl_shape == output_mkl_shape) {
-        CHECK((*output_tensor)->CopyFrom(add_tensor, output_tf_shape));
+        DCHECK((*output_tensor)->CopyFrom(add_tensor, output_tf_shape));
       } else {
-        auto add_md =
-            add_mkl_shape.IsMklTensor()
-                ? add_mkl_shape.GetMklLayout()
-                : memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
-                               output_mkl_shape.GetTfDataFormat());
-        auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_);
-        void* add_buf = static_cast<void*>(
-            const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
-        void* dst_buf =
-            static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
-        auto add = new memory(add_pd, add_buf);
-        auto dst = new memory(dst_pd, dst_buf);
-        auto reorder_desc = mkldnn::reorder::primitive_desc(add_pd, dst_pd);
-
-        std::vector<mkldnn::primitive> net;
-        net.push_back(mkldnn::reorder(reorder_desc, *add, *dst));
-        stream(stream::kind::eager).submit(net).wait();
+        if (add_mkl_shape.IsMklTensor()) {
+          auto add_md = add_mkl_shape.GetMklLayout();
+        } else {
+#ifdef ENABLE_MKLDNN_V1
+          auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
+              output_mkl_shape.GetTfDataFormat());
+          DCHECK_NE(output_format_tag, memory::format_tag::undef);
+          auto add_md = memory::desc(output_dims_mkl_order,
+                                     MklDnnType<Toutput>(), output_format_tag);
+#else
+          auto add_md =
+              memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
+                           output_mkl_shape.GetTfDataFormat());
+          auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_);
+#endif  // ENABLE_MKLDNN_V1
+          void* add_buf = static_cast<void*>(
+              const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
+          void* dst_buf =
+              static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
+          auto add = new MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
+          auto dst = new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
+          auto reorder_desc =
+              REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
+          CreateAndExecuteReorder(reorder_desc, *add, *dst, this->cpu_engine_);
+        }
       }
     }
   }

-  engine cpu_engine_ = engine(engine::cpu, 0);
+  engine cpu_engine_ = engine(ENGINE_CPU, 0);

  private:
   std::vector<int32> strides_;
@@ -883,7 +1030,7 @@ class MklConvOp : public OpKernel {
   bool fuse_add_ = false;

   float relu_up_bound_ = 0.0;
-  mkldnn::algorithm activation_alg_ = mkldnn::algorithm_undef;
+  mkldnn::algorithm activation_alg_ = ALGORITHM_UNDEF;

   int input_index_pad_ = 2;

@@ -892,15 +1039,27 @@ class MklConvOp : public OpKernel {
   const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
   const int kDilationH = 0, kDilationW = 1;

+  MKL_TENSOR_FORMAT_IN_C GetFilterTfDataFormat(
+      const MklDnnShape* filter_mkl_shape,
+      const ConvFwdPd& conv_prim_desc) const {
+#ifdef ENABLE_MKLDNN_V1
+    DCHECK(filter_mkl_shape);
+    return filter_mkl_shape->GetTfDataFormat();
+#else
+    return conv_prim_desc.weights_primitive_desc().desc().data.format;
+#endif  // ENABLE_MKLDNN_V1
+  }
+
   // Allocate persistent tensors for cached filter data and
   // cached filter memory descriptor (data format)
   void AllocatePersistentTensor(OpKernelContext* context,
                                 const ConvFwdPd& conv_prim_desc,
-                                Tensor** filter_tensor) {
+                                Tensor** filter_tensor,
+                                const MklDnnShape* filter_mkl_shape) {
     DCHECK(filter_tensor);
     TensorShape filter_tf_shape;
     filter_tf_shape.AddDim(
-        (conv_prim_desc.weights_primitive_desc().get_size() / sizeof(Tfilter)));
+        (conv_prim_desc.PRIMITIVE_DESC_WEIGHTS.get_size() / sizeof(Tfilter)));
     OP_REQUIRES_OK(context, context->allocate_persistent(
                                 DataTypeToEnum<Tfilter>::value, filter_tf_shape,
                                 &cached_filter_data_ptensor_, filter_tensor));
@@ -908,37 +1067,44 @@ class MklConvOp : public OpKernel {
     Tensor* second_tensor = nullptr;
     TensorShape filter_mkl_format;
     filter_mkl_format.AddDim(
-        sizeof(conv_prim_desc.weights_primitive_desc().desc().data.format) /
+        sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) /
         sizeof(DT_INT32));
     OP_REQUIRES_OK(context, context->allocate_persistent(
                                 DT_INT32, filter_mkl_format,
                                 &cached_filter_md_ptensor_, &second_tensor));
-    second_tensor->scalar<int32>()() =
-        conv_prim_desc.weights_primitive_desc().desc().data.format;
+    second_tensor->scalar<int32>()() = static_cast<int32>(
+        GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
   }

+  void AllocatePersistentTensor(OpKernelContext* context,
+                                const ConvFwdPd& conv_prim_desc,
+                                Tensor** filter_tensor) {
+    AllocatePersistentTensor(context, conv_prim_desc, filter_tensor, nullptr);
+  }
+
   void AllocateFilterOutputTensor(OpKernelContext* context,
                                   const ConvFwdPd& conv_prim_desc,
                                   const memory::dims& filter_dims_tf_order,
                                   Tensor** filter_tensor) {
-    CHECK_NOTNULL(filter_tensor);
-    auto filter_pd = conv_prim_desc.weights_primitive_desc();
+    DCHECK(filter_tensor);
+    auto filter_md = conv_prim_desc.PRIMITIVE_DESC_WEIGHTS;

-    // Allocate shape of Mkl tensor.
+    // Allocate shape of MKL tensor
     MklDnnShape filter_mkl_shape;
     filter_mkl_shape.SetMklTensor(true);
-    filter_mkl_shape.SetMklLayout(&filter_pd);
+    filter_mkl_shape.SetMklLayout(&filter_md);
     filter_mkl_shape.SetElemType(MklDnnType<Tfilter>());

     // The format of the filter is actually OIhw8i8o, but TF doesn't support
     // this format. Just use format::blocked for now because the layout
     // is stored in the MKL data.
     filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
-                                 filter_dims_tf_order, memory::format::blocked);
+                                 filter_dims_tf_order,
+                                 MKL_TENSOR_FORMAT_BLOCKED);

     // Allocate the data space for the filter to propagate as TF tensor.
     TensorShape filter_tf_shape;
-    filter_tf_shape.AddDim((filter_pd.get_size() / sizeof(Tfilter)));
+    filter_tf_shape.AddDim((filter_md.get_size() / sizeof(Tfilter)));

     AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor,
                               filter_tf_shape, filter_mkl_shape);
@@ -951,20 +1117,46 @@ class MklConvOp : public OpKernel {
                             MklDnnData<Tbias>* bias,
                             MklDnnData<Toutput>* output,
                             Tensor* filter_out_tensor) {
-    CHECK_NOTNULL(filter_out_tensor);
+    DCHECK(filter_out_tensor);

     // Create reorders between user layout and MKL layout if it is needed and
     // add it to the net before convolution. No need to check for output
     // reorder as we propagate output layout to the next layer.
-    src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc());
+    src->CheckReorderToOpMem(
+        MEMORY_PD_WITHOUT_DATA(conv_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));

-    // rather than re-order to a temp buffer, reorder directly to the
+    // Rather than re-ordering to a temp buffer, reorder directly to the
     // filter output tensor
-    filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(),
+    filter->CheckReorderToOpMem(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS,
                                 filter->GetTensorBuffer(filter_out_tensor));

     // Create convolution primitive and add it to net.
     std::vector<primitive> net;
+#ifdef ENABLE_MKLDNN_V1
+    std::vector<std::unordered_map<int, memory>> net_args;
+    if (bias) {
+      DCHECK(fuse_biasadd_);
+      net.push_back(convolution_forward(conv_prim_desc));
+      net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
+                          {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
+                          {MKLDNN_ARG_BIAS, bias->GetOpMem()},
+                          {MKLDNN_ARG_DST, output->GetOpMem()}});
+    } else {
+      DCHECK(!fuse_biasadd_);
+      net.push_back(convolution_forward(conv_prim_desc));
+      net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
+                          {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
+                          {MKLDNN_ARG_DST, output->GetOpMem()}});
+    }
+    stream cpu_stream(cpu_engine_);
+    DCHECK_EQ(net.size(), net_args.size());
+    for (size_t i = 0; i < net.size(); ++i) {
+      net.at(i).execute(cpu_stream, net_args.at(i));
+    }
+    cpu_stream.wait();
+#else
     if (bias) {
       DCHECK(fuse_biasadd_);
       net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
@@ -976,8 +1168,8 @@ class MklConvOp : public OpKernel {
                                         filter->GetOpMem(),
                                         output->GetOpMem()));
     }
-
     stream(stream::kind::eager).submit(net).wait();
+#endif  // ENABLE_MKLDNN_V1
   }

   // LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
@@ -990,8 +1182,55 @@ class MklConvOp : public OpKernel {
     return (cached_filter_data_tensor.NumElements() == 0);
   }

   // Cache the converted filter in a persistent tensor.
   // Only one thread can execute this method at any given time.
+#ifdef ENABLE_MKLDNN_V1
+  void CacheFilter(OpKernelContext* context,
+                   const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
+                   Tfilter* filter_data, const Tensor& filter_tensor,
+                   MklDnnData<Tfilter>& filter, const memory::desc& filter_md,
+                   const MklDnnShape& filter_mkl_shape) LOCKS_EXCLUDED(mu_) {
+    mutex_lock lock(mu_);
+    const Tensor& cached_filter_data_tensor =
+        *cached_filter_data_ptensor_.AccessTensor(context);
+
+    // If filter is already cached, there's nothing to do.
+    if (cached_filter_data_tensor.NumElements() > 0) {
+      return;
+    }
+
+    // Otherwise, cache filter
+    filter.SetUsrMem(filter_md, &filter_tensor);
+    filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
+                               this->cpu_engine_);
+    filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
+
+    Tensor* filter_tensor_ptr = nullptr;
+    AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr,
+                             &filter_mkl_shape);
+    void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
+    size_t cached_filter_data_size = filter.GetOpMem().get_desc().get_size();
+    memcpy(cached_filter_data, filter_data, cached_filter_data_size);
+  }
+
+  bool AreMemoryDescriptorsEqual(const memory::desc& filter_md,
+                                 const Tensor& cached_filter_md) {
+    auto filter_md_data = filter_md.data;
+    const char* filter_data = reinterpret_cast<const char*>(&filter_md_data);
+
+    auto cached_filter_md_data = cached_filter_md.scalar<int64>()();
+    const char* cached_filter_data =
+        reinterpret_cast<const char*>(&cached_filter_md_data);
+
+    for (size_t i = 0; i < sizeof(filter_md_data); ++i) {
+      if (*filter_data++ != *cached_filter_data++) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+#else
   void CacheFilter(OpKernelContext* context,
                    const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
                    Tfilter* filter_data, const Tensor& filter_tensor,
@@ -1018,22 +1257,26 @@ class MklConvOp : public OpKernel {
         filter.GetOpMem().get_primitive_desc().get_size();
     memcpy(cached_filter_data, filter_data, cached_filter_data_size);
   }
+#endif  // ENABLE_MKLDNN_V1

   Tfilter* GetCachedFilter(OpKernelContext* context,
-                           const memory::format& filter_mf)
-      LOCKS_EXCLUDED(mu_) {
+                           const MEMORY_DESC& filter_md) LOCKS_EXCLUDED(mu_) {
     tf_shared_lock lock(mu_);
     const Tensor& cached_filter_data =
         *cached_filter_data_ptensor_.AccessTensor(context);
     const Tensor& cached_filter_md =
         *cached_filter_md_ptensor_.AccessTensor(context);

     // Check if the memory descriptor of the cached weights is same as
-    // filter_mf. If so, we can used the cached weights; otherwise
-    // return NULL.
-    // TODO (bhavanis): Do we need to cast filter_mf before the check?
+    // filter_md. If so, we can use the cached weights; otherwise
+    // return NULL.
+#ifdef ENABLE_MKLDNN_V1
+    if (cached_filter_md.scalar<int64>().size() &&
+        AreMemoryDescriptorsEqual(filter_md, cached_filter_md)) {
+#else
     if (cached_filter_md.scalar<int32>().size() &&
-        cached_filter_md.scalar<int32>()() == filter_mf) {
+        cached_filter_md.scalar<int32>()() == filter_md) {
+#endif  // ENABLE_MKLDNN_V1
       return static_cast<Tfilter*>(
           const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
     }
@@ -1069,26 +1312,26 @@ class MklFusedConvOp
                   errors::InvalidArgument(
                       "Fused Conv2D must have one extra argument: bias."));
     } else if (fused_ops == std::vector<string>{"Relu"}) {
-      this->set_fuse_activation(true, mkldnn::eltwise_relu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
     } else if (fused_ops == std::vector<string>{"Relu6"}) {
-      this->set_fuse_activation(true, mkldnn::eltwise_bounded_relu, 6.0);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
     } else if (fused_ops == std::vector<string>{"Elu"}) {
-      this->set_fuse_activation(true, mkldnn::eltwise_elu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_elu);
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
       this->set_fuse_biasadd(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_relu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
       OP_REQUIRES(context, num_args == 1,
                   errors::InvalidArgument(
                       "Fused Conv2D must have one extra argument: bias."));
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
       this->set_fuse_biasadd(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_bounded_relu, 6.0);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
       OP_REQUIRES(context, num_args == 1,
                   errors::InvalidArgument(
                       "Fused Conv2D must have one extra argument: bias."));
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
       this->set_fuse_biasadd(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_elu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_elu);
       OP_REQUIRES(context, num_args == 1,
                   errors::InvalidArgument(
                       "Fused Conv2D must have one extra argument: bias."));
@@ -1102,7 +1345,7 @@ class MklFusedConvOp
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
       this->set_fuse_biasadd(true);
       this->set_fuse_add(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_relu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
       OP_REQUIRES(
           context, num_args == 2,
           errors::InvalidArgument(
@@ -1110,7 +1353,7 @@ class MklFusedConvOp
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
       this->set_fuse_biasadd(true);
       this->set_fuse_add(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_bounded_relu, 6.0);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
       OP_REQUIRES(
           context, num_args == 2,
           errors::InvalidArgument(
@@ -1118,7 +1361,7 @@ class MklFusedConvOp
     } else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
       this->set_fuse_biasadd(true);
       this->set_fuse_add(true);
-      this->set_fuse_activation(true, mkldnn::eltwise_elu);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_elu);
       OP_REQUIRES(
           context, num_args == 2,
           errors::InvalidArgument(
@@ -1274,7 +1517,7 @@ class MklQuantizedConv2DOp
           (255.0f * 127.0f * output_range);
     }
     params.post_op_params.push_back(
-        {"output_scale", mkldnn::algorithm_undef, scales});
+        {"output_scale", ALGORITHM_UNDEF, scales});
   }
 }

@@ -1293,7 +1536,6 @@ class MklQuantizedConv2DOp
     const float* min_filter = min_filter_vector.flat<float>().data();
     const float* max_filter = max_filter_vector.flat<float>().data();

-    std::vector<mkldnn::primitive> net;
     if (bias_enabled) {
       if (std::is_same<Tbias, qint32>::value) {
         return static_cast<Tbias*>(
@@ -1315,21 +1557,21 @@ class MklQuantizedConv2DOp
       } else {
         bias_attr.set_output_scales(1, scales);
       }
-      auto bias_pd =
-          memory::primitive_desc({{static_cast<int>(bias_tensor.NumElements())},
-                                  MklDnnType<Tbias>(),
-                                  memory::format::x},
-                                 this->cpu_engine_);
-
+
+      auto bias_md =
+          MEMORY_PD_CONSTRUCTOR(static_cast<int>(bias_tensor.NumElements()),
+                                Tbias, x, this->cpu_engine_);
       void* bias_buf = static_cast<void*>(
           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
-      input_bias_ = new memory(bias_pd, bias_buf);
-      scaled_bias_ = new memory(conv_fwd_pd->bias_primitive_desc());
-      auto reorder_desc = mkldnn::reorder::primitive_desc(
-          input_bias_->get_primitive_desc(), scaled_bias_->get_primitive_desc(),
+      input_bias_ =
+          new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf);
+      scaled_bias_ = new MEMORY_CONSTRUCTOR_WITHOUT_DATA(
+          conv_fwd_pd->PRIMITIVE_DESC_BIAS, this->cpu_engine_);
+      auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
+          input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
           bias_attr);
-      net.push_back(mkldnn::reorder(reorder_desc, *input_bias_, *scaled_bias_));
-      stream(stream::kind::eager).submit(net).wait();
+      CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
+                              this->cpu_engine_);
       return reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
     } else {
       return nullptr;
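[Editor's note] The quantized-bias path above folds the requantization scales into the reorder itself via a primitive_attr. A hedged standalone sketch of an attribute-carrying reorder in MKL-DNN 1.x (values are illustrative):

    // scaled_reorder_sketch.cc -- illustrative only
    #include "mkldnn.hpp"

    int main() {
      using namespace mkldnn;
      engine eng(engine::kind::cpu, 0);
      stream s(eng);
      memory::desc md({4}, memory::data_type::f32, memory::format_tag::x);
      memory src(md, eng), dst(md, eng);

      primitive_attr attr;
      attr.set_output_scales(0, {2.0f});  // mask 0: one scale for everything
      reorder::primitive_desc pd(eng, md, eng, md, attr);
      reorder(pd).execute(s, {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}});
      s.wait();
      return 0;
    }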
@@ -1358,7 +1600,7 @@ class MklQuantizedConv2DReluOp
     MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output, bias_enabled,
                          is_depthwise>::ExtendConvFwdParams(context, params);
     params.post_op_params.push_back(
-        {"activation", mkldnn::eltwise_relu, {1.0, 0.0, 0.0}});
+        {"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}});
   }
 };

@@ -1415,23 +1657,23 @@ class MklQuantizedConv2DSumReluOp
     // If it is not then it is DT_INT8 and is scaled appropriately.
     if (summand_type == DT_QUINT8)
       params.post_op_params.push_back(
-          {"sum", mkldnn::algorithm_undef, {scale_summand / scale_output}});
+          {"sum", ALGORITHM_UNDEF, {scale_summand / scale_output}});
     else
       params.post_op_params.push_back(
           {"sum",
-           mkldnn::algorithm_undef,
+           ALGORITHM_UNDEF,
            {255.0f * scale_summand / (scale_output * 127.0f)}});
     } else {
-      params.post_op_params.push_back({"sum", mkldnn::algorithm_undef, {1.0}});
+      params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}});
     }
     params.post_op_params.push_back(
-        {"activation", mkldnn::eltwise_relu, {1.0, 0.0, 0.0}});
+        {"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}});
   }

   void AllocateOutputTensor(OpKernelContext* context,
                             const ConvFwdPd& conv_prim_desc,
                             const memory::dims& output_dims_mkl_order,
-                            memory::format output_tf_format,
+                            MKL_TENSOR_FORMAT output_tf_format,
                             Tensor** output_tensor) override {
     int summand_idx = context->num_inputs() / 2 - 1;
     if (std::is_same<Toutput, quint8>::value) {
@@ -1503,20 +1745,22 @@ class MklQuantizedConv2DSumReluOp
             summand_mkl_shape.IsMklTensor()
                 ? summand_mkl_shape.GetMklLayout()
                 : memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
-                               memory::format::nhwc);
+                               MEMORY_FORMAT::nhwc);
+#ifndef ENABLE_MKLDNN_V1
         auto summand_pd = memory::primitive_desc(summand_md, this->cpu_engine_);
+#endif  // !ENABLE_MKLDNN_V1
         void* summand_buf =
             static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
         void* dst_buf =
             static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
-        summand_ = new memory(summand_pd, summand_buf);
-        dst_ = new memory(conv_prim_desc.dst_primitive_desc(), dst_buf);
-        auto reorder_desc = mkldnn::reorder::primitive_desc(
-            summand_pd, conv_prim_desc.dst_primitive_desc(), reorder_attr);
-
-        std::vector<mkldnn::primitive> net;
-        net.push_back(mkldnn::reorder(reorder_desc, *summand_, *dst_));
-        stream(stream::kind::eager).submit(net).wait();
+        summand_ =
+            new MEMORY_CONSTRUCTOR(SUMMAND_MD, this->cpu_engine_, summand_buf);
+        dst_ = new MEMORY_CONSTRUCTOR(conv_prim_desc.PRIMITIVE_DESC_DST,
+                                      this->cpu_engine_, dst_buf);
+        auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
+            SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
+            reorder_attr);
+        CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_);
   }

   memory* summand_ = nullptr;
@@ -1970,5 +2214,36 @@ TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED);
 TF_CALL_float(REGISTER_MKL_CPU_3D);
 TF_CALL_bfloat16(REGISTER_MKL_CPU_3D);

+#undef ADD_MD
+#undef ALGORITHM
+#undef ALGORITHM_UNDEF
+#undef CPU_STREAM
+#undef DATA_WITH_ENGINE
+#undef DST_MD
+#undef ENGINE_CPU
+#undef GET_DESC
+#undef GET_MEMORY_DESC_CONSTRUCTOR
+#undef GET_SRC_DESC_FROM_OP_PD
+#undef GET_WEIGHTS_DESC_FROM_OP_PD
+#undef GET_WEIGHTS_FORMAT_FROM_OP_PD
+#undef IS_FILTER_REORDER_NEEDED
+#undef IS_SRC_REORDER_NEEDED
+#undef MEMORY_CONSTRUCTOR
+#undef MEMORY_CONSTRUCTOR_USING_MEM_PD
+#undef MEMORY_CONSTRUCTOR_WITHOUT_DATA
+#undef MEMORY_DESC
+#undef MEMORY_FORMAT
+#undef MEMORY_PD_CONSTRUCTOR
+#undef MEMORY_PD_WITHOUT_DATA
+#undef MKL_TENSOR_FORMAT
+#undef MKL_TENSOR_FORMAT_BLOCKED
+#undef MKL_TENSOR_FORMAT_IN_C
+#undef PRIMITIVE_DESC_BIAS
+#undef PRIMITIVE_DESC_DST
+#undef PRIMITIVE_DESC_SRC
+#undef PRIMITIVE_DESC_WEIGHTS
+#undef REORDER_PD_CONSTRUCTOR
+#undef REORDER_PD_CONSTRUCTOR_WITH_ATTR
+#undef SUMMAND_MD
+
 }  // namespace tensorflow
 #endif  // INTEL_MKL

@@ -40,13 +40,21 @@ limitations under the License.
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"

+#ifndef ENABLE_MKLDNN_V1
 using mkldnn::convolution_direct;
+#endif  // !ENABLE_MKLDNN_V1
 using mkldnn::convolution_forward;
 using mkldnn::prop_kind;
 using mkldnn::stream;

 namespace tensorflow {

+#ifdef ENABLE_MKLDNN_V1
+#define MKLDNN_SIZE_DTYPE long int
+#else
+#define MKLDNN_SIZE_DTYPE int
+#endif  // ENABLE_MKLDNN_V1
+
 class MklDnnConvUtil {
  protected:
   OpKernelContext* context_;  // We don't own this.
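[Editor's note] MKLDNN_SIZE_DTYPE exists because v1.x widened dimension values to 64 bits (mkldnn_dim_t), while v0.x used plain int, so the helper vectors that feed memory::dims must switch element type with the build flag. A small self-contained sketch of the same switch (BUILD_V1 stands in for ENABLE_MKLDNN_V1):

    // size_dtype_sketch.cc -- illustrative only
    #include <cstdint>
    #include <vector>

    #ifdef BUILD_V1
    using mkldnn_size_t = std::int64_t;  // matches "long int" in the diff
    #else
    using mkldnn_size_t = int;
    #endif

    int main() {
      // N, C, H, W placeholders, as in the conv-util helpers below.
      std::vector<mkldnn_size_t> mkldnn_sizes(4, -1);
      return static_cast<int>(mkldnn_sizes.size());
    }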
@@ -137,7 +145,7 @@ class MklDnnConvUtil {
     int input_cols = static_cast<int>(input_cols_raw);

     // MKL-DNN always requires input in NCHW format for Conv2D.
-    std::vector<int> mkldnn_sizes(4, -1);
+    std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
     mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
     mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
     mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
@@ -161,7 +169,7 @@ class MklDnnConvUtil {
     int input_cols = static_cast<int>(input_cols_raw);

     // MKL-DNN always requires input in NCDHW format for Conv3D.
-    std::vector<int> mkldnn_sizes(5, -1);
+    std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
     mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
     mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
     mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
@@ -225,7 +233,7 @@ class MklDnnConvUtil {
     // GOIHW = (group, out_depth, in_depth, rows, cols)
     // Specifically for depthwise G=filter_indepth, O=filter_outdepth, I=1
     if (is_depthwise) {
-      std::vector<int> mkldnn_sizes(5, -1);
+      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
       mkldnn_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth;
       mkldnn_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth;
       mkldnn_sizes[MKL_GROUP_FILTER_DIM_I] = 1;
@@ -234,7 +242,7 @@ class MklDnnConvUtil {

       *filter_dims = mkldnn_sizes;
     } else {
-      std::vector<int> mkldnn_sizes(4, -1);
+      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
       mkldnn_sizes[MklDnnDims::Dim_O] = filter_out_depth;
       mkldnn_sizes[MklDnnDims::Dim_I] = filter_in_depth;
       mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
@@ -262,7 +270,7 @@ class MklDnnConvUtil {

     // MKL-DNN always needs filter in OIDHW format.
     // OIDHW = (out_depth, in_depth, planes, rows, cols)
-    std::vector<int> mkldnn_sizes(5, -1);
+    std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
     mkldnn_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth;
     mkldnn_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth;
     mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
@@ -455,14 +463,14 @@ class MklDnnConvUtil {

     if (is_conv2d) {
       // For Conv2D, MKL-DNN always needs output in NCHW format.
-      std::vector<int> mkldnn_sizes(4, -1);
+      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
       mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
       mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
       mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
       mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
       *output_dims_mkl_order = mkldnn_sizes;
     } else {
-      std::vector<int> mkldnn_sizes(5, -1);
+      std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(5, -1);
       mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
       mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
       mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
@@ -624,6 +632,8 @@ class MklDummyOp : public OpKernel {
   }
 };

+#undef MKLDNN_SIZE_DTYPE
+
 }  // namespace tensorflow

 #endif  // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_

@@ -1194,6 +1194,27 @@ inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
   return memory::desc(md);
 }

+inline void CreateAndExecuteReorder(const reorder::primitive_desc& reorder_desc,
+                                    const memory& src_mem,
+                                    const memory& dst_mem,
+                                    const engine& engine) {
+  std::vector<primitive> net;
+#ifdef ENABLE_MKLDNN_V1
+  net.push_back(mkldnn::reorder(reorder_desc));
+  std::vector<MemoryArgsMap> net_args;
+  net_args.push_back({{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}});
+  DCHECK_EQ(net.size(), net_args.size());
+  stream cpu_stream(engine);
+  for (size_t i = 0; i < net.size(); ++i) {
+    net.at(i).execute(cpu_stream, net_args.at(i));
+  }
+  cpu_stream.wait();
+#else
+  net.push_back(mkldnn::reorder(reorder_desc, src_mem, dst_mem));
+  stream(stream::kind::eager).submit(net).wait();
+#endif  // ENABLE_MKLDNN_V1
+}
+
 template <typename T>
 inline primitive FindOrCreateReorder(const memory* from, const memory* to);
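[Editor's note] A hedged usage sketch for the CreateAndExecuteReorder() helper added above. It assumes an ENABLE_MKLDNN_V1 build (memory::get_desc() is a v1.x call) and that mkl_util.h is included; the wrapper function and variable names are hypothetical:

    // reorder_helper_usage.cc -- illustrative only
    #include "tensorflow/core/util/mkl_util.h"

    void CopyWithReorder(const mkldnn::engine& eng, mkldnn::memory& src,
                         mkldnn::memory& dst) {
      auto pd = mkldnn::reorder::primitive_desc(eng, src.get_desc(), eng,
                                                dst.get_desc());
      // One helper call replaces the old net.push_back(...) + submit().wait().
      tensorflow::CreateAndExecuteReorder(pd, src, dst, eng);
    }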