Merge pull request from Intel-tensorflow:dnn0x_clean_matmul_fused_2

PiperOrigin-RevId: 350702539
Change-Id: I9cca95389c1717391675ee2c005bc42eb868be60
TensorFlower Gardener 2021-01-07 22:41:57 -08:00
commit 230472ae51
2 changed files with 19 additions and 30 deletions
tensorflow/core/kernels/mkl


@@ -109,17 +109,17 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
    memory::dims weight_dims = memory::dims({channel, k});
    memory::dims bias_dims = memory::dims({channel});
    memory::dims dst_dims = memory::dims({batch, channel});
-   MEMORY_FORMAT src_format = MEMORY_FORMAT::nc;
-   MEMORY_FORMAT weight_format =
-       transpose_b_ ? MEMORY_FORMAT::oi : MEMORY_FORMAT::io;
+   memory::format_tag src_format = memory::format_tag::nc;
+   memory::format_tag weight_format =
+       transpose_b_ ? memory::format_tag::oi : memory::format_tag::io;
    // Set weight format for primitive:
    //   1. const, let MKL-DNN determine format because it will be cached;
    //   2. var, keep the original format to avoid reordering.
    MklDnnMatMulFwdParams matmul_params(
        src_dims, weight_dims, bias_dims, dst_dims, src_format,
-       (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format,
-       MEMORY_FORMAT::nc);
+       (this->is_weight_const_) ? memory::format_tag::any : weight_format,
+       memory::format_tag::nc);
    // Extend the basic parameters for data types and fusions.
    ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
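
Note: this hunk swaps the v0.x compatibility macro MEMORY_FORMAT for oneDNN v1.x's memory::format_tag, and the in-code comment spells out the weight-format policy: pass format_tag::any for const weights so the library can pick, and the kernel can cache, its preferred layout. A minimal standalone sketch of that pattern against the raw oneDNN v1.x API, with a plain inner-product primitive standing in for the kernel's fused-MatMul primitive and made-up shapes:

```cpp
#include <iostream>
#include "dnnl.hpp"  // oneDNN v1.x C++ API (formerly MKL-DNN)

using namespace dnnl;

int main() {
  engine eng(engine::kind::cpu, 0);

  // Made-up logical shapes: batch = 4, k = 8, channel = 16.
  memory::desc src_md({4, 8}, memory::data_type::f32, memory::format_tag::nc);
  // format_tag::any lets oneDNN choose a (possibly blocked) weight layout;
  // a const weight can then be reordered into it once and cached.
  memory::desc wei_md({16, 8}, memory::data_type::f32,
                      memory::format_tag::any);
  memory::desc bia_md({16}, memory::data_type::f32, memory::format_tag::x);
  memory::desc dst_md({4, 16}, memory::data_type::f32, memory::format_tag::nc);

  inner_product_forward::primitive_desc pd(
      {prop_kind::forward_inference, src_md, wei_md, bia_md, dst_md}, eng);

  // The v1.x query used throughout this diff: weights_desc() returns the
  // memory::desc the primitive actually expects, with no .desc() indirection
  // (which is why the ENABLE_MKLDNN_V1 #ifdef further down can disappear).
  std::cout << "weights bytes: " << pd.weights_desc().get_size() << "\n";
  return 0;
}
```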
@@ -160,7 +160,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
                              output_tf_shape, output_mkl_shape,
                              native_format);
      auto output_format_tag =
-         MklTensorFormatToMklDnnDataFormat(MKL_TENSOR_FORMAT_NC);
+         MklTensorFormatToMklDnnDataFormat(MklTensorFormat::FORMAT_NC);
      auto add_md =
          add_mkl_shape.IsMklTensor()
              ? add_mkl_shape.GetMklLayout()
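
Note: MklTensorFormatToMklDnnDataFormat and MklTensorFormat are TF-side helpers; the old macro argument simply becomes a scoped enum value. The format_tag::nc that falls out is nothing exotic: it is oneDNN's alias for the plain row-major 2-D layout ab, as this small check (runnable against oneDNN v1.x) confirms:

```cpp
#include <cassert>
#include "dnnl.hpp"

using namespace dnnl;

int main() {
  // nc is an alias of the generic plain 2-D tag ab, so both descriptors
  // describe the same physical layout and compare equal.
  memory::desc a({2, 3}, memory::data_type::f32, memory::format_tag::nc);
  memory::desc b({2, 3}, memory::data_type::f32, memory::format_tag::ab);
  assert(a == b);
  return 0;
}
```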
@@ -180,12 +180,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
                            mkldnn::memory::format_tag::x);
      }
-     auto fuse_add_src_ =
-         MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
-     auto fuse_add_dst_ =
-         MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
+     auto fuse_add_src_ = memory(add_md, this->cpu_engine_, add_buf);
+     auto fuse_add_dst_ = memory(dst_md, this->cpu_engine_, dst_buf);
      auto reorder_desc =
-         REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
+         ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);
      CreateAndExecuteReorder(reorder_desc, fuse_add_src_, fuse_add_dst_,
                              this->cpu_engine_, ctx);
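
Note: ReorderPd and CreateAndExecuteReorder are the kernel's own wrappers; judging from the argument order in this hunk, ReorderPd maps onto oneDNN's reorder::primitive_desc(src_engine, src_md, dst_engine, dst_md). A hedged sketch of the raw v1.x sequence those wrappers stand for, with illustrative shapes and tags:

```cpp
#include <vector>
#include "dnnl.hpp"

using namespace dnnl;

int main() {
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  // Source and destination descriptors with deliberately different layouts.
  memory::desc add_md({2, 3}, memory::data_type::f32, memory::format_tag::nc);
  memory::desc dst_md({2, 3}, memory::data_type::f32, memory::format_tag::cn);

  std::vector<float> add_buf(6, 1.f), dst_buf(6, 0.f);
  memory src(add_md, eng, add_buf.data());  // cf. fuse_add_src_ above
  memory dst(dst_md, eng, dst_buf.data());  // cf. fuse_add_dst_ above

  // reorder::primitive_desc(src_engine, src_md, dst_engine, dst_md) is the
  // v1.x call the ReorderPd(engine, add_md, engine, dst_md) wrapper mirrors.
  reorder::primitive_desc rpd(eng, add_md, eng, dst_md);
  reorder(rpd).execute(s, src, dst);
  s.wait();
  return 0;
}
```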
@@ -215,19 +213,17 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
            ? src_mkl_shape.GetMklLayout()
            : memory::desc(src_dims, MklDnnType<T>(), src_format);
-   if (IS_SRC_REORDER_NEEDED(src_md, matmul_pd, matmul_prim)) {
+   if (src_md != matmul_pd->src_desc()) {
      src_mkl.SetUsrMem(src_md, src_data);
-     src_mkl.CheckReorderToOpMem(
-         MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_SRC,
-                                this->cpu_engine_),
-         ctx);
+     src_mkl.CheckReorderToOpMem(matmul_pd.get()->src_desc(),
+                                 this->cpu_engine_, ctx);
      src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
    }
    // Get cached data when weight is const.
    const memory::desc weight_md =
        memory::desc(weight_dims, MklDnnType<T>(), weight_format);
-   if (IS_WEIGHTS_REORDER_NEEDED(weight_md, matmul_pd, matmul_prim)) {
+   if (weight_md != matmul_pd->weights_desc()) {
      T* cached_weight_data = nullptr;
      if (this->is_weight_const_) {
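
Note: with v1.x, the IS_SRC_REORDER_NEEDED / IS_WEIGHTS_REORDER_NEEDED macros collapse into a plain memory::desc comparison: reorder only when the user-supplied layout differs from the one the primitive wants. A sketch of that check, where a hypothetical MaybeReorder helper plays the role of the kernel's CheckReorderToOpMem:

```cpp
#include <vector>
#include "dnnl.hpp"

using namespace dnnl;

// Reorder `buf` into the primitive's preferred layout only when the
// user-provided descriptor differs from the one the primitive expects;
// mirrors `if (src_md != matmul_pd->src_desc())` in the hunk above.
memory MaybeReorder(const memory::desc& user_md, const memory::desc& op_md,
                    void* buf, engine& eng, stream& s) {
  memory user_mem(user_md, eng, buf);
  if (user_md == op_md) return user_mem;  // layouts match: no copy needed
  memory op_mem(op_md, eng);              // library-allocated scratch buffer
  reorder(user_mem, op_mem).execute(s, user_mem, op_mem);
  s.wait();
  return op_mem;
}

int main() {
  engine eng(engine::kind::cpu, 0);
  stream s(eng);
  memory::desc nc({2, 3}, memory::data_type::f32, memory::format_tag::nc);
  memory::desc cn({2, 3}, memory::data_type::f32, memory::format_tag::cn);
  std::vector<float> buf(6, 1.f);
  memory m = MaybeReorder(nc, cn, buf.data(), eng, s);
  (void)m;
  return 0;
}
```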
@@ -235,13 +231,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
        this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor,
                          weight_mkl, weight_md);
      }
-#ifdef ENABLE_MKLDNN_V1
-     cached_weight_data = this->GetCachedWeight(
-         ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd));
-#else
-     cached_weight_data = this->GetCachedWeight(
-         ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd).desc());
-#endif
+     cached_weight_data =
+         this->GetCachedWeight(ctx, matmul_pd->weights_desc());
    }
// Cache weight may fail when it gets different format in different
@@ -251,10 +242,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
        weight_data = cached_weight_data;
      } else {
        weight_mkl.SetUsrMem(weight_md, weight_data);
-       weight_mkl.CheckReorderToOpMem(
-           MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS,
-                                  this->cpu_engine_),
-           ctx);
+       weight_mkl.CheckReorderToOpMem(matmul_pd.get()->weights_desc(),
+                                      this->cpu_engine_, ctx);
        weight_data =
            reinterpret_cast<T*>(weight_mkl.GetOpMem().get_data_handle());
      }


@@ -406,12 +406,12 @@ class MklQuantizeV2Op : public OpKernel {
    TensorShape output_tf_shape;
    if (src_mkl_shape.IsMklTensor()) {
      output_mkl_shape.SetMklTensor(true);
-     output_mkl_shape.SetMklLayout(&DST_MD);
+     output_mkl_shape.SetMklLayout(&dst_md);
      output_mkl_shape.SetElemType(MklDnnType<T>());
      output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(),
                                   src_mkl_shape.GetSizesAsMklDnnDims(),
                                   src_mkl_shape.GetTfDataFormat());
-     output_tf_shape.AddDim(DST_MD.get_size() / sizeof(T));
+     output_tf_shape.AddDim(dst_md.get_size() / sizeof(T));
    } else {
      output_mkl_shape.SetMklTensor(false);
      output_tf_shape = MklDnnDimsToTFShape(output_dims);
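
Note: in this last hunk the DST_MD macro becomes the real dst_md descriptor. The get_size() / sizeof(T) idiom matters because memory::desc::get_size() reports the byte footprint of the chosen layout, and for blocked formats that includes padding beyond the logical element count, so the backing TF buffer is sized from the descriptor rather than from the shape. A small illustration with made-up 4-D dims:

```cpp
#include <iostream>
#include "dnnl.hpp"

using namespace dnnl;

int main() {
  // Logical shape 1x5x4x4 (80 elements). The blocked tag nChw16c pads the
  // channel dim up to the block size 16, so its allocation is larger.
  memory::desc plain({1, 5, 4, 4}, memory::data_type::f32,
                     memory::format_tag::nchw);
  memory::desc blocked({1, 5, 4, 4}, memory::data_type::f32,
                       memory::format_tag::nChw16c);
  std::cout << plain.get_size() / sizeof(float) << "\n";    // 80
  std::cout << blocked.get_size() / sizeof(float) << "\n";  // 256 (16*4*4)
  return 0;
}
```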