Merge pull request #45817 from Intel-tensorflow:dnn0x_clean_matmul_fused_2
PiperOrigin-RevId: 350702539
Change-Id: I9cca95389c1717391675ee2c005bc42eb868be60
This commit is contained in:
commit
230472ae51
@ -109,17 +109,17 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
|||||||
memory::dims weight_dims = memory::dims({channel, k});
|
memory::dims weight_dims = memory::dims({channel, k});
|
||||||
memory::dims bias_dims = memory::dims({channel});
|
memory::dims bias_dims = memory::dims({channel});
|
||||||
memory::dims dst_dims = memory::dims({batch, channel});
|
memory::dims dst_dims = memory::dims({batch, channel});
|
||||||
MEMORY_FORMAT src_format = MEMORY_FORMAT::nc;
|
memory::format_tag src_format = memory::format_tag::nc;
|
||||||
MEMORY_FORMAT weight_format =
|
memory::format_tag weight_format =
|
||||||
transpose_b_ ? MEMORY_FORMAT::oi : MEMORY_FORMAT::io;
|
transpose_b_ ? memory::format_tag::oi : memory::format_tag::io;
|
||||||
|
|
||||||
// Set weight format for primitive:
|
// Set weight format for primitive:
|
||||||
// 1. const, let MKL-DNN determine format because it will be cached;
|
// 1. const, let MKL-DNN determine format because it will be cached;
|
||||||
// 2. var, keep the original format to avoid reordering.
|
// 2. var, keep the original format to avoid reordering.
|
||||||
MklDnnMatMulFwdParams matmul_params(
|
MklDnnMatMulFwdParams matmul_params(
|
||||||
src_dims, weight_dims, bias_dims, dst_dims, src_format,
|
src_dims, weight_dims, bias_dims, dst_dims, src_format,
|
||||||
(this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format,
|
(this->is_weight_const_) ? memory::format_tag::any : weight_format,
|
||||||
MEMORY_FORMAT::nc);
|
memory::format_tag::nc);
|
||||||
|
|
||||||
// Extend the basic parameters for data types and fusions.
|
// Extend the basic parameters for data types and fusions.
|
||||||
ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
|
ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
|
||||||
@ -160,7 +160,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
|||||||
output_tf_shape, output_mkl_shape,
|
output_tf_shape, output_mkl_shape,
|
||||||
native_format);
|
native_format);
|
||||||
auto output_format_tag =
|
auto output_format_tag =
|
||||||
MklTensorFormatToMklDnnDataFormat(MKL_TENSOR_FORMAT_NC);
|
MklTensorFormatToMklDnnDataFormat(MklTensorFormat::FORMAT_NC);
|
||||||
auto add_md =
|
auto add_md =
|
||||||
add_mkl_shape.IsMklTensor()
|
add_mkl_shape.IsMklTensor()
|
||||||
? add_mkl_shape.GetMklLayout()
|
? add_mkl_shape.GetMklLayout()
|
||||||
@ -180,12 +180,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
|||||||
mkldnn::memory::format_tag::x);
|
mkldnn::memory::format_tag::x);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto fuse_add_src_ =
|
auto fuse_add_src_ = memory(add_md, this->cpu_engine_, add_buf);
|
||||||
MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
|
auto fuse_add_dst_ = memory(dst_md, this->cpu_engine_, dst_buf);
|
||||||
auto fuse_add_dst_ =
|
|
||||||
MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
|
|
||||||
auto reorder_desc =
|
auto reorder_desc =
|
||||||
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
|
ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);
|
||||||
|
|
||||||
CreateAndExecuteReorder(reorder_desc, fuse_add_src_, fuse_add_dst_,
|
CreateAndExecuteReorder(reorder_desc, fuse_add_src_, fuse_add_dst_,
|
||||||
this->cpu_engine_, ctx);
|
this->cpu_engine_, ctx);
|
||||||
@ -215,19 +213,17 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
|||||||
? src_mkl_shape.GetMklLayout()
|
? src_mkl_shape.GetMklLayout()
|
||||||
: memory::desc(src_dims, MklDnnType<T>(), src_format);
|
: memory::desc(src_dims, MklDnnType<T>(), src_format);
|
||||||
|
|
||||||
if (IS_SRC_REORDER_NEEDED(src_md, matmul_pd, matmul_prim)) {
|
if (src_md != matmul_pd->src_desc()) {
|
||||||
src_mkl.SetUsrMem(src_md, src_data);
|
src_mkl.SetUsrMem(src_md, src_data);
|
||||||
src_mkl.CheckReorderToOpMem(
|
src_mkl.CheckReorderToOpMem(matmul_pd.get()->src_desc(),
|
||||||
MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_SRC,
|
this->cpu_engine_, ctx);
|
||||||
this->cpu_engine_),
|
|
||||||
ctx);
|
|
||||||
src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
|
src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get cached data when weight is const.
|
// Get cached data when weight is const.
|
||||||
const memory::desc weight_md =
|
const memory::desc weight_md =
|
||||||
memory::desc(weight_dims, MklDnnType<T>(), weight_format);
|
memory::desc(weight_dims, MklDnnType<T>(), weight_format);
|
||||||
if (IS_WEIGHTS_REORDER_NEEDED(weight_md, matmul_pd, matmul_prim)) {
|
if (weight_md != matmul_pd->weights_desc()) {
|
||||||
T* cached_weight_data = nullptr;
|
T* cached_weight_data = nullptr;
|
||||||
|
|
||||||
if (this->is_weight_const_) {
|
if (this->is_weight_const_) {
|
||||||
@ -235,13 +231,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
|||||||
this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor,
|
this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor,
|
||||||
weight_mkl, weight_md);
|
weight_mkl, weight_md);
|
||||||
}
|
}
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
cached_weight_data =
|
||||||
cached_weight_data = this->GetCachedWeight(
|
this->GetCachedWeight(ctx, matmul_pd->weights_desc());
|
||||||
ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd));
|
|
||||||
#else
|
|
||||||
cached_weight_data = this->GetCachedWeight(
|
|
||||||
ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd).desc());
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache weight may fail when it gets different format in different
|
// Cache weight may fail when it gets different format in different
|
||||||
@ -251,10 +242,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
|||||||
weight_data = cached_weight_data;
|
weight_data = cached_weight_data;
|
||||||
} else {
|
} else {
|
||||||
weight_mkl.SetUsrMem(weight_md, weight_data);
|
weight_mkl.SetUsrMem(weight_md, weight_data);
|
||||||
weight_mkl.CheckReorderToOpMem(
|
weight_mkl.CheckReorderToOpMem(matmul_pd.get()->weights_desc(),
|
||||||
MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS,
|
this->cpu_engine_, ctx);
|
||||||
this->cpu_engine_),
|
|
||||||
ctx);
|
|
||||||
weight_data =
|
weight_data =
|
||||||
reinterpret_cast<T*>(weight_mkl.GetOpMem().get_data_handle());
|
reinterpret_cast<T*>(weight_mkl.GetOpMem().get_data_handle());
|
||||||
}
|
}
|
||||||
|
@ -406,12 +406,12 @@ class MklQuantizeV2Op : public OpKernel {
|
|||||||
TensorShape output_tf_shape;
|
TensorShape output_tf_shape;
|
||||||
if (src_mkl_shape.IsMklTensor()) {
|
if (src_mkl_shape.IsMklTensor()) {
|
||||||
output_mkl_shape.SetMklTensor(true);
|
output_mkl_shape.SetMklTensor(true);
|
||||||
output_mkl_shape.SetMklLayout(&DST_MD);
|
output_mkl_shape.SetMklLayout(&dst_md);
|
||||||
output_mkl_shape.SetElemType(MklDnnType<T>());
|
output_mkl_shape.SetElemType(MklDnnType<T>());
|
||||||
output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(),
|
output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(),
|
||||||
src_mkl_shape.GetSizesAsMklDnnDims(),
|
src_mkl_shape.GetSizesAsMklDnnDims(),
|
||||||
src_mkl_shape.GetTfDataFormat());
|
src_mkl_shape.GetTfDataFormat());
|
||||||
output_tf_shape.AddDim(DST_MD.get_size() / sizeof(T));
|
output_tf_shape.AddDim(dst_md.get_size() / sizeof(T));
|
||||||
} else {
|
} else {
|
||||||
output_mkl_shape.SetMklTensor(false);
|
output_mkl_shape.SetMklTensor(false);
|
||||||
output_tf_shape = MklDnnDimsToTFShape(output_dims);
|
output_tf_shape = MklDnnDimsToTFShape(output_dims);
|
||||||
|
Loading…
Reference in New Issue
Block a user