Merge pull request #44823 from Intel-tensorflow:dnn0x_cleanup_fwd_conv

PiperOrigin-RevId: 350041167
Change-Id: Ic11ff723860813c70d5ad0427e5672da77867246

This commit is contained in: commit 6876288ead
@@ -43,7 +43,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -65,7 +64,7 @@ struct MklConvFwdParams {
memory::dims dilations;
memory::dims padding_left;
memory::dims padding_right;
MKL_TENSOR_FORMAT tf_fmt;
MklTensorFormat tf_fmt;
bool native_format;
string dtypes = string("");
struct PostOpParam {
@@ -80,7 +79,7 @@ struct MklConvFwdParams {
memory::dims bias_dims, memory::dims dst_dims,
memory::dims strides, memory::dims dilations,
memory::dims padding_left, memory::dims padding_right,
MKL_TENSOR_FORMAT tf_fmt, bool native_format)
MklTensorFormat tf_fmt, bool native_format)
: src_dims(src_dims),
filter_dims(filter_dims),
bias_dims(bias_dims),
@@ -99,7 +98,7 @@ template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
class MklConvFwdPrimitive : public MklPrimitive {
public:
explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
: MklPrimitive(engine(ENGINE_CPU, 0)) {
: MklPrimitive(engine(engine::kind::cpu, 0)) {
// Create convolution primitive
if (context_.conv_fwd == nullptr) {
Setup(convFwdDims);
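Note on the hunk above: the v0.x ENGINE_CPU macro is replaced by constructing the engine directly with the mkldnn 1.x API. A minimal standalone sketch of the engine/stream setup this relies on (plain mkldnn 1.x, not the TensorFlow wrapper classes; later sketches reuse cpu_engine/cpu_stream and the using-directive):

    #include "mkldnn.hpp"
    using namespace mkldnn;

    int main() {
      engine cpu_engine(engine::kind::cpu, 0);  // CPU engine, index 0
      stream cpu_stream(cpu_engine);            // execution stream on that engine
      return 0;
    }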
@@ -115,8 +114,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
void Execute(const Tinput* src_data, const Tfilter* filter_data,
const Tbias* bias_data, const Toutput* dst_data,
std::shared_ptr<stream> fwd_stream) {
// TODO: Create a common function and avoid the duplicate code
#ifdef ENABLE_MKLDNN_THREADPOOL
// TODO: Create a common function and avoid the duplicate code
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
context_.filter_mem->set_data_handle(
@@ -139,16 +138,13 @@ class MklConvFwdPrimitive : public MklPrimitive {
context_.dst_mem->set_data_handle(
static_cast<void*>(const_cast<Toutput*>(dst_data)));
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1

DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
context_.fwd_primitives.at(i).execute(*fwd_stream,
context_.fwd_primitives_args.at(i));
}
#else
fwd_stream->submit(context_.fwd_primitives);
#endif // ENABLE_MKLDNN_V1

// After execution, set data handle back
context_.src_mem->set_data_handle(DummyData);
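The v1.x execution path above runs each cached primitive with an explicit argument map instead of submitting a primitive list to the stream. A standalone sketch of the same pattern (conv_pd, cpu_engine, cpu_stream and the raw buffers are assumed from the earlier sketch; they are illustrative, not the TF context_ members):

    // Bind user buffers to the layouts chosen by the primitive descriptor.
    memory src_mem(conv_pd.src_desc(), cpu_engine, src_buffer);
    memory wei_mem(conv_pd.weights_desc(), cpu_engine, wei_buffer);
    memory dst_mem(conv_pd.dst_desc(), cpu_engine, dst_buffer);

    convolution_forward conv(conv_pd);
    std::unordered_map<int, memory> args = {{MKLDNN_ARG_SRC, src_mem},
                                            {MKLDNN_ARG_WEIGHTS, wei_mem},
                                            {MKLDNN_ARG_DST, dst_mem}};
    conv.execute(cpu_stream, args);  // v1.x: primitive + args map per call
    cpu_stream.wait();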
@@ -168,13 +164,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
Execute(src_data, filter_data, nullptr, dst_data, fwd_stream);
}

#ifndef ENABLE_MKLDNN_V1
// In MKL-DNN v1.x, memory format tags only provide a partial description
// of the memory layout. Hence, these functions are disabled for v1.x.
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }
memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; }
#endif // !ENABLE_MKLDNN_V1

std::shared_ptr<ConvFwdPd> GetPrimitiveDesc() const {
return context_.fwd_pd;
}
@@ -182,12 +171,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
private:
// Primitive reuse context for Conv2D Fwd op
struct ConvFwdContext {
#ifndef ENABLE_MKLDNN_V1
// Expected memory format for this primitive instance
memory::format src_fmt;
memory::format filter_fmt;
#endif // !ENABLE_MKLDNN_V1

// MKL-DNN memory
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> filter_mem;
@@ -208,18 +191,10 @@ class MklConvFwdPrimitive : public MklPrimitive {
std::shared_ptr<mkldnn::primitive> conv_fwd;

std::vector<mkldnn::primitive> fwd_primitives;

#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
#endif // ENABLE_MKLDNN_V1

ConvFwdContext()
:
#ifndef ENABLE_MKLDNN_V1
src_fmt(memory::format::any),
filter_fmt(memory::format::any),
#endif // !ENABLE_MKLDNN_V1
src_mem(nullptr),
: src_mem(nullptr),
filter_mem(nullptr),
bias_mem(nullptr),
dst_mem(nullptr),
@@ -228,52 +203,45 @@ class MklConvFwdPrimitive : public MklPrimitive {
filter_md(nullptr),
bias_md(nullptr),
fwd_pd(nullptr),
conv_fwd(nullptr) {
}
conv_fwd(nullptr) {}
};

void Setup(const MklConvFwdParams& convFwdDims) {
MEMORY_FORMAT user_data_fmt;
memory::format_tag user_data_fmt;
if (convFwdDims.native_format) {
user_data_fmt = MklTensorFormatToMklDnnDataFormat(convFwdDims.tf_fmt);
} else {
// Create memory descriptors for convolution data w/ no specified format
user_data_fmt = MEMORY_FORMAT::any;
user_data_fmt = memory::format_tag::any;
}
context_.src_md.reset(new memory::desc(
{convFwdDims.src_dims}, MklDnnType<Tinput>(), user_data_fmt));

context_.filter_md.reset(new memory::desc(
{convFwdDims.filter_dims}, MklDnnType<Tfilter>(), MEMORY_FORMAT::any));
context_.filter_md.reset(new memory::desc({convFwdDims.filter_dims},
MklDnnType<Tfilter>(),
memory::format_tag::any));

context_.dst_md.reset(new memory::desc(
{convFwdDims.dst_dims}, MklDnnType<Toutput>(), user_data_fmt));

if (!convFwdDims.bias_dims.empty())
context_.bias_md.reset(new memory::desc(
{convFwdDims.bias_dims}, MklDnnType<Tbias>(), MEMORY_FORMAT::any));
context_.bias_md.reset(new memory::desc({convFwdDims.bias_dims},
MklDnnType<Tbias>(),
memory::format_tag::any));

// Create a convolution descriptor
if (!convFwdDims.bias_dims.empty()) {
context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
*context_.filter_md, *context_.bias_md, *context_.dst_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
#ifndef ENABLE_MKLDNN_V1
convFwdDims.padding_right, padding_kind::zero));
#else
convFwdDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
prop_kind::forward, mkldnn::algorithm::convolution_direct,
*context_.src_md, *context_.filter_md, *context_.bias_md,
*context_.dst_md, convFwdDims.strides, convFwdDims.dilations,
convFwdDims.padding_left, convFwdDims.padding_right));
} else {
context_.fwd_desc.reset(new convolution_forward::desc(
prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md,
*context_.filter_md, *context_.dst_md, convFwdDims.strides,
convFwdDims.dilations, convFwdDims.padding_left,
#ifndef ENABLE_MKLDNN_V1
convFwdDims.padding_right, padding_kind::zero));
#else
prop_kind::forward, mkldnn::algorithm::convolution_direct,
*context_.src_md, *context_.filter_md, *context_.dst_md,
convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left,
convFwdDims.padding_right));
#endif // !ENABLE_MKLDNN_V1
}

context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
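The Setup() hunk drops the v0.x MEMORY_FORMAT/ALGORITHM/padding_kind macros in favor of the v1.x enums, and lets the library pick layouts via format_tag::any. A minimal sketch of the same descriptor/primitive-descriptor pattern with hard-coded NCHW dims (all values are illustrative, not taken from this file):

    memory::dims src_dims = {1, 3, 224, 224}, filt_dims = {64, 3, 3, 3};
    memory::dims dst_dims = {1, 64, 222, 222};
    memory::dims strides = {1, 1}, dilations = {0, 0}, pad_l = {0, 0}, pad_r = {0, 0};

    auto src_md  = memory::desc(src_dims,  memory::data_type::f32, memory::format_tag::any);
    auto filt_md = memory::desc(filt_dims, memory::data_type::f32, memory::format_tag::any);
    auto dst_md  = memory::desc(dst_dims,  memory::data_type::f32, memory::format_tag::any);

    auto fwd_desc = convolution_forward::desc(
        prop_kind::forward, algorithm::convolution_direct, src_md, filt_md,
        dst_md, strides, dilations, pad_l, pad_r);  // v1.x: no padding_kind argument
    auto conv_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);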
@@ -314,54 +282,32 @@ class MklConvFwdPrimitive : public MklPrimitive {
context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_));
}

#ifndef ENABLE_MKLDNN_V1
// Store the expected memory format
context_.src_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->src_primitive_desc().desc().data.format);

context_.filter_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->weights_primitive_desc().desc().data.format);
#endif // !ENABLE_MKLDNN_V1

// Create memory primitive based on dummy data
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData));
context_.filter_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(
context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));
context_.src_mem.reset(
new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData));
context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(),
cpu_engine_, DummyData));
context_.dst_mem.reset(
new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData));

// Create convolution primitive and add it to net
if (!convFwdDims.bias_dims.empty()) {
context_.bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD(
convFwdDims.bias_dims, Tbias, MEMORY_FORMAT::x, cpu_engine_,
DummyData));
#ifdef ENABLE_MKLDNN_V1
context_.bias_mem.reset(new memory(
{{convFwdDims.bias_dims}, MklDnnType<Tbias>(), memory::format_tag::x},
cpu_engine_, DummyData));
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{MKLDNN_ARG_BIAS, *context_.bias_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
{MKLDNN_ARG_DST, *context_.dst_mem}});
} else {
context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_WEIGHTS, *context_.filter_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
{MKLDNN_ARG_DST, *context_.dst_mem}});
}
#else
context_.conv_fwd.reset(new convolution_forward(
*context_.fwd_pd, *context_.src_mem, *context_.filter_mem,
*context_.bias_mem, *context_.dst_mem));
} else {
context_.conv_fwd.reset(
new convolution_forward(*context_.fwd_pd, *context_.src_mem,
*context_.filter_mem, *context_.dst_mem));
}
#endif // ENABLE_MKLDNN_V1
context_.fwd_primitives.push_back(*context_.conv_fwd);
}
@@ -650,12 +596,10 @@ class MklConvOp : public OpKernel {
auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(data_format_)
: TFDataFormatToMklDnn3DDataFormat(data_format_);

#ifdef ENABLE_MKLDNN_V1
auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt);
// NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU
OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef,
errors::InvalidArgument("Invalid data format"));
#endif // ENABLE_MKLDNN_V1

// If input is in MKL layout, then simply grab the layout; otherwise,
// construct TF layout for input.
@@ -667,19 +611,15 @@ class MklConvOp : public OpKernel {
auto src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
#ifdef ENABLE_MKLDNN_V1
: memory::desc(src_dims, MklDnnType<Tinput>(), mkl_fmt_tag);
#else
: memory::desc(src_dims, MklDnnType<Tinput>(), tf_fmt);
#endif // ENABLE_MKLDNN_V1
src.SetUsrMem(src_md, &src_tensor);

// Although filter shape (filter_dims) required is in MKL-DNN order,
// the layout is Tensorflow's layout (HWIO) and (HWIGO) for
// depthwise/group convolutions.
auto filter_format = is_conv2d ? (is_depthwise ? MEMORY_FORMAT::hwigo
: MEMORY_FORMAT::hwio)
: MEMORY_FORMAT::dhwio;
auto filter_format = is_conv2d ? (is_depthwise ? memory::format_tag::hwigo
: memory::format_tag::hwio)
: memory::format_tag::dhwio;

DCHECK(!filter_mkl_shape.IsMklTensor());
auto filter_md =
@@ -738,12 +678,9 @@ class MklConvOp : public OpKernel {

// Check whether src and filter need to be reordered.
Tinput* src_data = nullptr;
if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
if (src_md != conv_fwd_pd->src_desc()) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd),
cpu_engine_),
context);
src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_, context);
src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
} else {
src_data = static_cast<Tinput*>(
@@ -751,7 +688,7 @@ class MklConvOp : public OpKernel {
}

Tfilter* filter_data = nullptr;
if (IS_FILTER_REORDER_NEEDED(filter_md, conv_fwd_pd, conv_fwd)) {
if (filter_md != conv_fwd_pd->weights_desc()) {
bool is_filter_cached = false;
// If filter is a constant, we can avoid the conversion of filter from
// Tensorflow format to MKL format by caching the filter when it is
@@ -761,28 +698,20 @@ class MklConvOp : public OpKernel {
if (IsFilterCacheEmpty(context)) {
// Cache filter if it is not already cached.
CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor,
#ifdef ENABLE_MKLDNN_V1
filter, filter_md, filter_mkl_shape);
#else
filter, filter_md);
#endif // ENABLE_MKLDNN_V1
}
filter_data = GetCachedFilter(
context, GET_WEIGHTS_FORMAT_FROM_OP_PD(conv_fwd_pd, conv_fwd));
filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc());
is_filter_cached = (filter_data != nullptr);
}
if (!is_filter_cached) {
filter.SetUsrMem(filter_md, &filter_tensor);
if (filter_out_tensor == nullptr) {
filter.CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
cpu_engine_),
context);
filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), cpu_engine_,
context);
} else {
filter.CheckReorderToOpMem(
GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
cpu_engine_),
conv_fwd_pd->weights_desc(),
filter.GetTensorBuffer(filter_out_tensor), cpu_engine_,
context);
}
filter_data =
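In v1.x the IS_*_REORDER_NEEDED macros collapse into a direct comparison between the user memory descriptor and the one the primitive descriptor chose. A standalone sketch of that check-and-reorder pattern (user_md, conv_pd, user_buffer and the engine/stream are assumed from the earlier sketches):

    memory user_src(user_md, cpu_engine, user_buffer);
    memory conv_src = user_src;
    if (user_md != conv_pd.src_desc()) {             // layouts differ: reorder is needed
      conv_src = memory(conv_pd.src_desc(), cpu_engine);
      reorder(user_src, conv_src)
          .execute(cpu_stream, user_src, conv_src);  // user layout -> primitive layout
      cpu_stream.wait();
    }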
@@ -897,7 +826,8 @@ class MklConvOp : public OpKernel {
// NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by
// checking `fuse_biasadd_` flag.
if (fuse_add_) {
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""});
params.post_op_params.push_back(
{"sum", mkldnn::algorithm::undef, {1.0}, ""});
}
if (fuse_activation_) {
params.post_op_params.push_back(
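These PostOpParam entries are later translated into mkldnn post-ops attached to the convolution's primitive attributes. A minimal sketch of how a "sum" and an "activation" post-op are expressed with the v1.x API (scales and alpha are illustrative; fwd_desc and cpu_engine come from the earlier sketch):

    post_ops ops;
    ops.append_sum(1.0f);                                           // accumulate into dst
    ops.append_eltwise(1.0f, algorithm::eltwise_relu, 0.0f, 0.0f);  // fused ReLU

    primitive_attr attr;
    attr.set_post_ops(ops);
    auto fused_pd = convolution_forward::primitive_desc(fwd_desc, attr, cpu_engine);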
@@ -918,35 +848,27 @@ class MklConvOp : public OpKernel {
virtual void AllocateOutputTensor(OpKernelContext* context,
const ConvFwdPd& conv_prim_desc,
const memory::dims& output_dims_mkl_order,
MKL_TENSOR_FORMAT output_tf_format,
MklTensorFormat output_tf_format,
MklDnnShape* output_mkl_shape,
Tensor** output_tensor) {
DCHECK(output_tensor);
#ifdef ENABLE_MKLDNN_V1
auto dst_md = conv_prim_desc.dst_desc();
#else
auto dst_pd = conv_prim_desc.dst_primitive_desc();
auto dst_md = dst_pd.desc();
#endif // ENABLE_MKLDNN_V1

if (!std::is_same<Ttemp_output, Toutput>::value) {
dst_md.data.data_type =
static_cast<mkldnn_data_type_t>(MklDnnType<Toutput>());
#ifndef ENABLE_MKLDNN_V1
dst_pd = memory::primitive_desc(dst_md, cpu_engine_);
#endif // !ENABLE_MKLDNN_V1
}

// Allocate shape of MKL tensor
output_mkl_shape->SetMklTensor(true);
output_mkl_shape->SetMklLayout(&DST_MD);
output_mkl_shape->SetMklLayout(&dst_md);
output_mkl_shape->SetElemType(MklDnnType<Toutput>());
output_mkl_shape->SetTfLayout(output_dims_mkl_order.size(),
output_dims_mkl_order, output_tf_format);

// Allocate shape of TF tensor
TensorShape output_tf_shape;
output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput)));
output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput)));
if (native_format) {
output_tf_shape = output_mkl_shape->GetTfShape();
}
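The output tensor is sized from the opaque destination descriptor rather than from the logical shape, because a blocked layout may carry padding. A small sketch of the byte-size query the hunk relies on (dims and type are illustrative):

    auto dst_md = memory::desc({1, 64, 222, 222}, memory::data_type::f32,
                               memory::format_tag::nChw16c);  // blocked layout
    size_t bytes = dst_md.get_size();            // includes any layout padding
    size_t num_elems = bytes / sizeof(float);    // elements to allocate in the flat TF buffer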
@@ -972,23 +894,16 @@ class MklConvOp : public OpKernel {
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
output_tf_shape, *output_mkl_shape,
native_format);
#ifdef ENABLE_MKLDNN_V1
auto output_format_tag = MklTensorFormatToMklDnnDataFormat(
output_mkl_shape->GetTfDataFormat());
OP_REQUIRES(context, output_format_tag != memory::format_tag::undef,
errors::InvalidArgument(
"MklConvOp: AddN fusion: Invalid data format"));
#endif // ENABLE_MKLDNN_V1
auto add_md =
add_mkl_shape.IsMklTensor()
? add_mkl_shape.GetMklLayout()
: memory::desc(output_dims_mkl_order, MklDnnType<Toutput>(),
#ifdef ENABLE_MKLDNN_V1
output_format_tag);
#else
output_mkl_shape->GetTfDataFormat());
auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_);
#endif // ENABLE_MKLDNN_V1
void* add_buf = static_cast<void*>(
const_cast<Toutput*>(add_tensor.flat<Toutput>().data()));
void* dst_buf =
@@ -996,16 +911,14 @@ class MklConvOp : public OpKernel {
if (native_format) {
// We are simply deep copying the add_tensor to output_tensor without
// changing memory layout, hence using same memory descriptor.
ADD_MD = DST_MD =
add_md = dst_md =
memory::desc({add_tensor.NumElements()}, MklDnnType<Toutput>(),
mkldnn::memory::format_tag::x);
}
fuse_add_src_.reset(
new MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf));
fuse_add_dst_.reset(
new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf));
fuse_add_src_.reset(new memory(add_md, this->cpu_engine_, add_buf));
fuse_add_dst_.reset(new memory(dst_md, this->cpu_engine_, dst_buf));
auto reorder_desc =
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);

CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
this->cpu_engine_, context);
@@ -1017,7 +930,7 @@ class MklConvOp : public OpKernel {
}
}

engine cpu_engine_ = engine(ENGINE_CPU, 0);
engine cpu_engine_ = engine(engine::kind::cpu, 0);

private:
std::shared_ptr<mkldnn::memory> fuse_add_src_;
@@ -1041,7 +954,7 @@ class MklConvOp : public OpKernel {
// This variable is used for alpha in leakyrelu or upper bound in relu6
// depending on the context
float alpha_or_upbound_ = 0.0;
mkldnn::algorithm activation_alg_ = ALGORITHM_UNDEF;
mkldnn::algorithm activation_alg_ = mkldnn::algorithm::undef;

int input_index_pad_ = 2;

@@ -1050,15 +963,10 @@ class MklConvOp : public OpKernel {
const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;
const int kDilationH = 0, kDilationW = 1;

MKL_TENSOR_FORMAT_IN_C GetFilterTfDataFormat(
const MklDnnShape* filter_mkl_shape,
const ConvFwdPd& conv_prim_desc) const {
#ifdef ENABLE_MKLDNN_V1
MklTensorFormat GetFilterTfDataFormat(const MklDnnShape* filter_mkl_shape,
const ConvFwdPd& conv_prim_desc) const {
DCHECK(filter_mkl_shape);
return filter_mkl_shape->GetTfDataFormat();
#else
return conv_prim_desc.weights_primitive_desc().desc().data.format;
#endif // ENABLE_MKLDNN_V1
}

// Allocate persistent tensors for cached filter data and
@@ -1070,23 +978,13 @@ class MklConvOp : public OpKernel {
DCHECK(filter_tensor);
TensorShape filter_tf_shape;
filter_tf_shape.AddDim(
(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS.get_size() / sizeof(Tfilter)));
(conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter)));
OP_REQUIRES_OK(context, context->allocate_persistent(
DataTypeToEnum<Tfilter>::value, filter_tf_shape,
&cached_filter_data_ptensor_, filter_tensor));

Tensor* second_tensor = nullptr;
#ifndef ENABLE_MKLDNN_V1
TensorShape filter_mkl_format;
filter_mkl_format.AddDim(
sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) /
sizeof(DT_INT32));
OP_REQUIRES_OK(context, context->allocate_persistent(
DT_INT32, filter_mkl_format,
&cached_filter_md_ptensor_, &second_tensor));
second_tensor->scalar<int32>()() = static_cast<int32>(
GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc));
#else

// There is no tensor format in DNNL 1.x. So we cache the complete filter
// descriptor as flat byte array.
TensorShape cached_filter_md_shape;
@@ -1100,7 +998,6 @@ class MklConvOp : public OpKernel {
&cached_filter_md_ptensor_, &second_tensor));
*reinterpret_cast<memory::desc*>(second_tensor->flat<uint8>().data()) =
weights_desc;
#endif // !ENABLE_MKLDNN_V1
}

void AllocatePersistentTensor(OpKernelContext* context,
@@ -1114,7 +1011,7 @@ class MklConvOp : public OpKernel {
const memory::dims& filter_dims_tf_order,
Tensor** filter_tensor) {
DCHECK(filter_tensor);
auto filter_md = conv_prim_desc.PRIMITIVE_DESC_WEIGHTS;
auto filter_md = conv_prim_desc.weights_desc();

// Allocate shape of MKL tensor
MklDnnShape filter_mkl_shape;
@@ -1127,7 +1024,7 @@ class MklConvOp : public OpKernel {
// is stored in the MKL data.
filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
filter_dims_tf_order,
MKL_TENSOR_FORMAT_BLOCKED);
MklTensorFormat::FORMAT_BLOCKED);

// Allocate the data space for the filter to propagate as TF tensor.
TensorShape filter_tf_shape;
@@ -1150,17 +1047,15 @@ class MklConvOp : public OpKernel {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution. No need to check for output
// reorder as we propagate output layout to the next layer.
src->CheckReorderToOpMem(
MEMORY_PD_WITHOUT_DATA(conv_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
src->CheckReorderToOpMem(conv_prim_desc.src_desc(), cpu_engine_);

// Rather than re-ordering to a temp buffer, reorder directly to the
// filter output tensor
filter->CheckReorderToOpMem(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS,
filter->CheckReorderToOpMem(conv_prim_desc.weights_desc(),
filter->GetTensorBuffer(filter_out_tensor));

// Create convolution primitive and add it to net.
std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> net_args;
if (bias) {
DCHECK(fuse_biasadd_);
@@ -1168,31 +1063,15 @@ class MklConvOp : public OpKernel {
net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
{MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
{MKLDNN_ARG_BIAS, bias->GetOpMem()},
{ MKLDNN_ARG_DST,
output->GetOpMem() }});
{MKLDNN_ARG_DST, output->GetOpMem()}});
} else {
DCHECK(!fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc));
net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()},
{MKLDNN_ARG_WEIGHTS, filter->GetOpMem()},
{ MKLDNN_ARG_DST,
output->GetOpMem() }});
{MKLDNN_ARG_DST, output->GetOpMem()}});
}
ExecutePrimitive(net, &net_args, cpu_engine_);
#else
if (bias) {
DCHECK(fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
filter->GetOpMem(), bias->GetOpMem(),
output->GetOpMem()));
} else {
DCHECK(!fuse_biasadd_);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
filter->GetOpMem(),
output->GetOpMem()));
}
ExecutePrimitive(net, nullptr, cpu_engine_);
#endif // ENABLE_MKLDNN_V1
}

// TF_LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
@@ -1206,9 +1085,8 @@ class MklConvOp : public OpKernel {
return (cached_filter_data_tensor.NumElements() == 0);
}

// Cache the converted filter in a persistent tensor.
// Only one thread can execute this method at any given time.
#ifdef ENABLE_MKLDNN_V1
// Cache the converted filter in a persistent tensor.
// Only one thread can execute this method at any given time.
void CacheFilter(OpKernelContext* context,
const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
Tfilter* filter_data, const Tensor& filter_tensor,
@@ -1254,37 +1132,8 @@ class MklConvOp : public OpKernel {
return true;
}

#else
void CacheFilter(OpKernelContext* context,
const std::shared_ptr<ConvFwdPd>& conv_fwd_pd,
Tfilter* filter_data, const Tensor& filter_tensor,
MklDnnData<Tfilter>& filter, const memory::desc& filter_md)
TF_LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
const Tensor& cached_filter_data_tensor =
*cached_filter_data_ptensor_.AccessTensor(context);

// If filter is already cached, there's nothing to do.
if (cached_filter_data_tensor.NumElements() > 0) {
return;
}

// Otherwise, cache filter
filter.SetUsrMem(filter_md, &filter_tensor);
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc());
filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());

Tensor* filter_tensor_ptr = nullptr;
AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr);
void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr);
size_t cached_filter_data_size =
filter.GetOpMem().get_primitive_desc().get_size();
memcpy(cached_filter_data, filter_data, cached_filter_data_size);
}
#endif // ENABLE_MKLDNN_V1

Tfilter* GetCachedFilter(OpKernelContext* context,
const MEMORY_DESC& filter_md)
const memory::desc& filter_md)
TF_LOCKS_EXCLUDED(mu_) {
tf_shared_lock lock(mu_);
const Tensor& cached_filter_data =
@@ -1292,15 +1141,10 @@ class MklConvOp : public OpKernel {
const Tensor& cached_filter_md =
*cached_filter_md_ptensor_.AccessTensor(context);

// Check if the memory descriptor of the cached weights is same as
// filter_md. If so, we can use the cached weights; otherwise
// return nullptr.
#ifdef ENABLE_MKLDNN_V1
// Check if the memory descriptor of the cached weights is the same as
// filter_md. If so, we can use the cached weights; otherwise
// return nullptr.
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
#else
if (cached_filter_md.scalar<int32>().size() &&
cached_filter_md.scalar<int32>()() == filter_md) {
#endif // ENABLE_MKLDNN_V1
return static_cast<Tfilter*>(
const_cast<Tfilter*>(cached_filter_data.flat<Tfilter>().data()));
}
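With no format enum in DNNL 1.x, the filter-cache key becomes the full memory::desc, stored as raw bytes and compared on lookup, which is what the hunk above does with the flat uint8 tensor. A small standalone sketch of that pattern (the byte buffer stands in for the persistent tensor and is illustrative only):

    // Store: serialize the descriptor chosen by the primitive next to the reordered weights.
    memory::desc weights_desc = conv_pd.weights_desc();
    std::vector<uint8_t> cached_md_bytes(sizeof(memory::desc));
    std::memcpy(cached_md_bytes.data(), &weights_desc, sizeof(memory::desc));

    // Lookup: reuse the cached weights only if the descriptor still matches.
    const memory::desc* cached_md =
        reinterpret_cast<const memory::desc*>(cached_md_bytes.data());
    bool can_reuse = (*cached_md == conv_pd.weights_desc());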
@@ -1336,31 +1180,34 @@ class MklFusedConvOp
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"Relu"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
} else if (fused_ops == std::vector<string>{"Relu6"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
} else if (fused_ops == std::vector<string>{"Elu"}) {
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
} else if (fused_ops == std::vector<string>{"LeakyRelu"}) {
float leakyrelu_alpha;
OP_REQUIRES_OK(context,
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
leakyrelu_alpha);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
@@ -1369,7 +1216,8 @@ class MklFusedConvOp
float leakyrelu_alpha;
OP_REQUIRES_OK(context,
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
leakyrelu_alpha);
OP_REQUIRES(context, num_args == 1,
errors::InvalidArgument(
"Fused Conv2D must have one extra argument: bias."));
@@ -1383,7 +1231,7 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
OP_REQUIRES(
context, num_args == 2,
errors::InvalidArgument(
@@ -1391,7 +1239,8 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu6"}) {
this->set_fuse_biasadd(true);
this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
OP_REQUIRES(
context, num_args == 2,
errors::InvalidArgument(
@@ -1399,7 +1248,7 @@ class MklFusedConvOp
} else if (fused_ops == std::vector<string>{"BiasAdd", "Add", "Elu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_add(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
OP_REQUIRES(
context, num_args == 2,
errors::InvalidArgument(
@@ -1411,7 +1260,8 @@ class MklFusedConvOp
float leakyrelu_alpha;
OP_REQUIRES_OK(context,
context->GetAttr("leakyrelu_alpha", &leakyrelu_alpha));
this->set_fuse_activation(true, ALGORITHM::eltwise_relu, leakyrelu_alpha);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu,
leakyrelu_alpha);
OP_REQUIRES(
context, num_args == 2,
errors::InvalidArgument(
@@ -1459,13 +1309,14 @@ class MklFusedDepthwiseConvOp
this->set_fuse_biasadd(true);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_relu);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_bounded_relu,
6.0);
} else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
this->set_fuse_biasadd(true);
this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
this->set_fuse_activation(true, mkldnn::algorithm::eltwise_elu, 1.0);
} else {
OP_REQUIRES(context, false,
errors::Unimplemented("Fusion is not implemented: [",
@@ -1642,8 +1493,8 @@ class MklQuantizedConv2DOp
param_key.AddAsKey<float>(max_freezed_output);
param_key.AddAsKey<const float*>(min_filter);
param_key.AddAsKey<const float*>(max_filter);
params.post_op_params.push_back(
{"output_scale", ALGORITHM_UNDEF, scales, param_key.GetKey()});
params.post_op_params.push_back({"output_scale", mkldnn::algorithm::undef,
scales, param_key.GetKey()});
}
}
@@ -1696,31 +1547,27 @@ class MklQuantizedConv2DOp
bias_attr.set_output_scales(1, scales_);
}

auto bias_md =
MEMORY_PD_CONSTRUCTOR(static_cast<int>(bias_tensor.NumElements()),
Tbias, MEMORY_FORMAT::x, this->cpu_engine_);
auto bias_md = memory::desc({static_cast<int>(bias_tensor.NumElements())},
MklDnnType<Tbias>(), memory::format_tag::x);
void* bias_buf = static_cast<void*>(
const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
if (!input_bias_) {
input_bias_ =
new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf);
input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf);
} else {
input_bias_->set_data_handle(bias_buf);
}

if (!scaled_bias_buf_)
AllocTmpBuffer<Tbias>(context, &scaled_bias_tensor_,
GET_BIAS_DESC_FROM_OP_PD(conv_fwd_pd),
&scaled_bias_buf_);
conv_fwd_pd->bias_desc(), &scaled_bias_buf_);
if (!scaled_bias_) {
scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_,
scaled_bias_buf_);
scaled_bias_ = new memory(bias_md, this->cpu_engine_, scaled_bias_buf_);
} else {
scaled_bias_->set_data_handle(scaled_bias_buf_);
}
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
bias_attr);
auto reorder_desc =
ReorderPd(this->cpu_engine_, input_bias_->get_desc(),
this->cpu_engine_, scaled_bias_->get_desc(), bias_attr);
CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
this->cpu_engine_, context);
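The quantized path rescales the bias by running a reorder that carries an output_scales attribute. A standalone sketch of that idiom in plain mkldnn 1.x (scale value, mask, and sizes are illustrative; raw_bias_ptr is a hypothetical f32 buffer):

    std::vector<float> bias_scales = {0.5f};      // e.g. 1 / (scale_input * scale_filter)
    primitive_attr bias_attr;
    bias_attr.set_output_scales(0, bias_scales);  // mask 0: one scale for the whole tensor

    auto bias_md = memory::desc({64}, memory::data_type::f32, memory::format_tag::x);
    memory src_bias(bias_md, cpu_engine, raw_bias_ptr);
    memory scaled_bias(bias_md, cpu_engine);      // library-allocated destination

    auto rpd = reorder::primitive_desc(cpu_engine, bias_md, cpu_engine, bias_md, bias_attr);
    reorder(rpd).execute(cpu_stream, src_bias, scaled_bias);
    cpu_stream.wait();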
@@ -1754,7 +1601,7 @@ class MklQuantizedConv2DOp
DCHECK(bias_tensor);
TensorShape bias_tf_shape;
bias_tf_shape.AddDim(
(conv_prim_desc.PRIMITIVE_DESC_BIAS.get_size() / sizeof(Tbias)));
(conv_prim_desc.bias_desc().get_size() / sizeof(Tbias)));
OP_REQUIRES_OK(context, context->allocate_persistent(
DataTypeToEnum<Tbias>::value, bias_tf_shape,
&cached_bias_data_ptensor_, bias_tensor));
@@ -1787,7 +1634,7 @@ class MklQuantizedConv2DOp
AllocatePersistentTensor(context, *conv_fwd_pd, &bias_tensor_ptr);
void* cached_bias_data = const_cast<void*>(
static_cast<const void*>(bias_tensor_ptr->flat<Tbias>().data()));
size_t cached_bias_data_size = scaled_bias->GET_DESC.get_size();
size_t cached_bias_data_size = scaled_bias->get_desc().get_size();
memcpy(cached_bias_data, bias_data, cached_bias_data_size);
}

@@ -1822,7 +1669,7 @@ class MklQuantizedConv2DReluOp
is_depthwise>::ExtendConvFwdParams(context, params);

params.post_op_params.push_back(
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""});
{"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
}
};
@@ -1868,26 +1715,30 @@ class MklQuantizedConv2DSumReluOp
// if summand_type is also DT_QUINT8 as the scale_output,
// the scaling factor of 255.0f cancels each other and thus is avoided.
// If it is not then it is DT_INT8 and is scaled appropriately.
if (summand_type == DT_QUINT8)
params.post_op_params.push_back(
{"sum", ALGORITHM_UNDEF, {scale_summand / scale_output}, ""});
else
if (summand_type == DT_QUINT8) {
params.post_op_params.push_back({"sum",
mkldnn::algorithm::undef,
{scale_summand / scale_output},
""});
} else {
params.post_op_params.push_back(
{"sum",
ALGORITHM_UNDEF,
mkldnn::algorithm::undef,
{255.0f * scale_summand / (scale_output * 127.0f)},
""});
}
} else {
params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}, ""});
params.post_op_params.push_back(
{"sum", mkldnn::algorithm::undef, {1.0}, ""});
}
params.post_op_params.push_back(
{"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}, ""});
{"activation", mkldnn::algorithm::eltwise_relu, {1.0, 0.0, 0.0}, ""});
}

void AllocateOutputTensor(OpKernelContext* context,
const ConvFwdPd& conv_prim_desc,
const memory::dims& output_dims_mkl_order,
MKL_TENSOR_FORMAT output_tf_format,
MklTensorFormat output_tf_format,
MklDnnShape* output_mkl_shape,
Tensor** output_tensor) override {
int summand_idx = context->num_inputs() / 2 - 1;
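The sum post-op scale in the hunk above folds the quint8 (255 quantization levels) versus qint8 (127 levels) ranges into a single factor. A short sketch of the two cases (variable names are illustrative):

    float sum_scale;
    if (summand_is_quint8) {
      // Summand and output are both unsigned 8-bit: the 255 factors cancel.
      sum_scale = scale_summand / scale_output;
    } else {
      // Signed 8-bit summand feeding an unsigned 8-bit output: rescale 127 -> 255 range.
      sum_scale = 255.0f * scale_summand / (scale_output * 127.0f);
    }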
@@ -1966,21 +1817,17 @@ class MklQuantizedConv2DSumReluOp
summand_mkl_shape.IsMklTensor()
? summand_mkl_shape.GetMklLayout()
: memory::desc(output_dims_mkl_order, MklDnnType<Tbias>(),
MEMORY_FORMAT::nhwc);
#ifndef ENABLE_MKLDNN_V1
auto summand_pd = memory::primitive_desc(summand_md, this->cpu_engine_);
#endif // !ENABLE_MKLDNN_V1
memory::format_tag::nhwc);
void* summand_buf =
static_cast<void*>(const_cast<Tbias*>(summand.flat<Tbias>().data()));
void* dst_buf =
static_cast<void*>((*output_tensor)->flat<Ttemp_output>().data());
summand_.reset(
new MEMORY_CONSTRUCTOR(SUMMAND_MD, this->cpu_engine_, summand_buf));
dst_.reset(new MEMORY_CONSTRUCTOR(conv_prim_desc.PRIMITIVE_DESC_DST,
this->cpu_engine_, dst_buf));
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
reorder_attr);
summand_.reset(new memory(summand_md, this->cpu_engine_, summand_buf));
dst_.reset(
new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf));
auto reorder_desc =
ReorderPd(this->cpu_engine_, summand_md, this->cpu_engine_,
conv_prim_desc.dst_desc(), reorder_attr);
CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
context);
}
@@ -42,20 +42,13 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

#ifndef ENABLE_MKLDNN_V1
using mkldnn::convolution_direct;
#endif // !ENABLE_MKLDNN_V1
using mkldnn::convolution_forward;
using mkldnn::prop_kind;
using mkldnn::stream;

namespace tensorflow {

#ifdef ENABLE_MKLDNN_V1
#define MKLDNN_SIZE_DTYPE memory::dim
#else
#define MKLDNN_SIZE_DTYPE int
#endif // ENABLE_MKLDNN_V1

using ConvFwdDesc = mkldnn::convolution_forward::desc;
using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
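The MKLDNN_SIZE_DTYPE macro above exists because v1.x sizes convolution dimensions with memory::dim (a 64-bit integer) while v0.x used plain int. A tiny sketch of how such dims end up being built (values illustrative):

    memory::dim stride_rows = 2, stride_cols = 2;       // 64-bit in v1.x (was int in v0.x)
    memory::dims strides = {stride_rows, stride_cols};  // memory::dims is std::vector<memory::dim>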