Merge pull request #45817 from Intel-tensorflow:dnn0x_clean_matmul_fused_2
PiperOrigin-RevId: 350702539
Change-Id: I9cca95389c1717391675ee2c005bc42eb868be60
commit 230472ae51
tensorflow/core/kernels/mkl
@@ -109,17 +109,17 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
     memory::dims weight_dims = memory::dims({channel, k});
     memory::dims bias_dims = memory::dims({channel});
     memory::dims dst_dims = memory::dims({batch, channel});
-    MEMORY_FORMAT src_format = MEMORY_FORMAT::nc;
-    MEMORY_FORMAT weight_format =
-        transpose_b_ ? MEMORY_FORMAT::oi : MEMORY_FORMAT::io;
+    memory::format_tag src_format = memory::format_tag::nc;
+    memory::format_tag weight_format =
+        transpose_b_ ? memory::format_tag::oi : memory::format_tag::io;

     // Set weight format for primitive:
     // 1. const, let MKL-DNN determine format because it will be cached;
     // 2. var, keep the original format to avoid reordering.
     MklDnnMatMulFwdParams matmul_params(
         src_dims, weight_dims, bias_dims, dst_dims, src_format,
-        (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format,
-        MEMORY_FORMAT::nc);
+        (this->is_weight_const_) ? memory::format_tag::any : weight_format,
+        memory::format_tag::nc);

     // Extend the basic parameters for data types and fusions.
     ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
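With MKL-DNN v0.x support dropped, the MEMORY_FORMAT macro gives way to the v1.x memory::format_tag type, and format_tag::any lets the library choose a blocked weight layout when the weight is constant (the reordered copy gets cached, per the comment above). A minimal standalone sketch, assuming the oneDNN v1.x C++ API under its dnnl namespace (the code above uses the mkldnn alias) and illustrative dims:

#include <iostream>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);

  const memory::dim batch = 2, k = 3, channel = 4;
  memory::desc src_md({batch, k}, memory::data_type::f32,
                      memory::format_tag::nc);
  // format_tag::any: defer the weight layout choice to the primitive;
  // the actual layout is queried from the primitive_desc afterwards.
  memory::desc weight_md({channel, k}, memory::data_type::f32,
                         memory::format_tag::any);
  memory::desc bias_md({channel}, memory::data_type::f32,
                       memory::format_tag::x);
  memory::desc dst_md({batch, channel}, memory::data_type::f32,
                      memory::format_tag::nc);

  inner_product_forward::desc d(prop_kind::forward_inference, src_md,
                                weight_md, bias_md, dst_md);
  inner_product_forward::primitive_desc pd(d, eng);

  // The layout the primitive actually expects for the weights.
  std::cout << "weights size in bytes: " << pd.weights_desc().get_size()
            << "\n";
  return 0;
}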
@@ -160,7 +160,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
                                 output_tf_shape, output_mkl_shape,
                                 native_format);
       auto output_format_tag =
-          MklTensorFormatToMklDnnDataFormat(MKL_TENSOR_FORMAT_NC);
+          MklTensorFormatToMklDnnDataFormat(MklTensorFormat::FORMAT_NC);
       auto add_md =
           add_mkl_shape.IsMklTensor()
               ? add_mkl_shape.GetMklLayout()
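MKL_TENSOR_FORMAT_NC was a macro whose expansion depended on the library version; with only v1.x left, the scoped enum value MklTensorFormat::FORMAT_NC is used directly. A rough sketch of that pattern, mirroring the role of MklTensorFormatToMklDnnDataFormat (the enum body and helper name here are illustrative, not TF's exact definitions):

#include "dnnl.hpp"

// Illustrative subset; TF's MklTensorFormat has more values.
enum class MklTensorFormat { FORMAT_NHWC, FORMAT_NCHW, FORMAT_NC, FORMAT_X };

// Maps the TF-side enum to a oneDNN format tag.
dnnl::memory::format_tag ToMklDnnFormat(MklTensorFormat f) {
  switch (f) {
    case MklTensorFormat::FORMAT_NHWC: return dnnl::memory::format_tag::nhwc;
    case MklTensorFormat::FORMAT_NCHW: return dnnl::memory::format_tag::nchw;
    case MklTensorFormat::FORMAT_NC:   return dnnl::memory::format_tag::nc;
    case MklTensorFormat::FORMAT_X:    return dnnl::memory::format_tag::x;
  }
  return dnnl::memory::format_tag::undef;
}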
@@ -180,12 +180,10 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
                   mkldnn::memory::format_tag::x);
         }

-        auto fuse_add_src_ =
-            MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
-        auto fuse_add_dst_ =
-            MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
+        auto fuse_add_src_ = memory(add_md, this->cpu_engine_, add_buf);
+        auto fuse_add_dst_ = memory(dst_md, this->cpu_engine_, dst_buf);
         auto reorder_desc =
-            REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
+            ReorderPd(this->cpu_engine_, add_md, this->cpu_engine_, dst_md);

         CreateAndExecuteReorder(reorder_desc, fuse_add_src_, fuse_add_dst_,
                                 this->cpu_engine_, ctx);
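Under v1.x the MEMORY_CONSTRUCTOR and REORDER_PD_CONSTRUCTOR macros are no longer needed: a memory object is built directly from (desc, engine, buffer), and ReorderPd is a TF helper that, judging by its signature, wraps reorder::primitive_desc. A standalone sketch of the same add-fusion reorder in the raw oneDNN v1.x API (dnnl namespace; dims and buffers are illustrative):

#include <vector>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  memory::desc add_md({2, 4}, memory::data_type::f32, memory::format_tag::nc);
  memory::desc dst_md({2, 4}, memory::data_type::f32, memory::format_tag::cn);

  std::vector<float> add_buf(8, 1.f), dst_buf(8, 0.f);
  // No MEMORY_CONSTRUCTOR: memory is built from desc, engine, user buffer.
  memory fuse_add_src(add_md, eng, add_buf.data());
  memory fuse_add_dst(dst_md, eng, dst_buf.data());

  // reorder::primitive_desc replaces REORDER_PD_CONSTRUCTOR.
  reorder::primitive_desc reorder_pd(eng, add_md, eng, dst_md);
  reorder(reorder_pd).execute(s, fuse_add_src, fuse_add_dst);
  s.wait();
  return 0;
}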
@@ -215,19 +213,17 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
               ? src_mkl_shape.GetMklLayout()
               : memory::desc(src_dims, MklDnnType<T>(), src_format);

-     if (IS_SRC_REORDER_NEEDED(src_md, matmul_pd, matmul_prim)) {
+     if (src_md != matmul_pd->src_desc()) {
        src_mkl.SetUsrMem(src_md, src_data);
-       src_mkl.CheckReorderToOpMem(
-           MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_SRC,
-                                  this->cpu_engine_),
-           ctx);
+       src_mkl.CheckReorderToOpMem(matmul_pd.get()->src_desc(),
+                                   this->cpu_engine_, ctx);
        src_data = reinterpret_cast<T*>(src_mkl.GetOpMem().get_data_handle());
      }

      // Get cached data when weight is const.
      const memory::desc weight_md =
          memory::desc(weight_dims, MklDnnType<T>(), weight_format);
-     if (IS_WEIGHTS_REORDER_NEEDED(weight_md, matmul_pd, matmul_prim)) {
+     if (weight_md != matmul_pd->weights_desc()) {
        T* cached_weight_data = nullptr;

        if (this->is_weight_const_) {
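IS_SRC_REORDER_NEEDED hid a version-specific layout comparison; the v1.x memory::desc supports direct equality, so the check becomes a plain comparison against the descriptor the primitive expects. A sketch of the pattern (reorder_if_needed is a hypothetical helper, not TF's CheckReorderToOpMem):

#include "dnnl.hpp"

// Compare the user's descriptor with the one the primitive_desc expects,
// and reorder only on mismatch.
dnnl::memory reorder_if_needed(const dnnl::memory::desc& user_md,
                               const dnnl::memory::desc& op_md,
                               void* user_data, dnnl::engine& eng,
                               dnnl::stream& s) {
  dnnl::memory user_mem(user_md, eng, user_data);
  if (user_md == op_md) return user_mem;  // layouts already match
  dnnl::memory op_mem(op_md, eng);        // library-allocated scratch buffer
  dnnl::reorder(dnnl::reorder::primitive_desc(eng, user_md, eng, op_md))
      .execute(s, user_mem, op_mem);
  return op_mem;
}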
@@ -235,13 +231,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
          this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor,
                            weight_mkl, weight_md);
        }
-#ifdef ENABLE_MKLDNN_V1
-       cached_weight_data = this->GetCachedWeight(
-           ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd));
-#else
-       cached_weight_data = this->GetCachedWeight(
-           ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd).desc());
-#endif
+       cached_weight_data =
+           this->GetCachedWeight(ctx, matmul_pd->weights_desc());
      }

      // Cache weight may fail when it gets different format in different
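As the deleted #else branch shows, the v0.x query had to be unwrapped with .desc(), while the v1.x weights_desc() returns the memory::desc directly, so the whole #ifdef collapses to one call. A sketch of the caching idea behind CacheWeight/GetCachedWeight (WeightCache is a hypothetical class, not TF's implementation): reorder constant weights once into the layout the primitive expects and reuse the result on later calls:

#include <optional>
#include "dnnl.hpp"

class WeightCache {
 public:
  // Returns the cached weights if their layout matches what the primitive
  // expects; performs and stores the reorder on first use.
  dnnl::memory Get(const dnnl::memory::desc& user_md,
                   const dnnl::memory::desc& op_md, void* user_data,
                   dnnl::engine& eng, dnnl::stream& s) {
    if (cached_ && cached_->get_desc() == op_md) return *cached_;
    dnnl::memory user_mem(user_md, eng, user_data);
    cached_.emplace(op_md, eng);
    dnnl::reorder(dnnl::reorder::primitive_desc(eng, user_md, eng, op_md))
        .execute(s, user_mem, *cached_);
    s.wait();
    return *cached_;
  }

 private:
  std::optional<dnnl::memory> cached_;
};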
@@ -251,10 +242,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
        weight_data = cached_weight_data;
      } else {
        weight_mkl.SetUsrMem(weight_md, weight_data);
-       weight_mkl.CheckReorderToOpMem(
-           MEMORY_PD_WITHOUT_DATA(matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS,
-                                  this->cpu_engine_),
-           ctx);
+       weight_mkl.CheckReorderToOpMem(matmul_pd.get()->weights_desc(),
+                                      this->cpu_engine_, ctx);
        weight_data =
            reinterpret_cast<T*>(weight_mkl.GetOpMem().get_data_handle());
      }
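The weight fallback path mirrors the source path two hunks up: the same descriptor comparison and two-argument CheckReorderToOpMem overload replace the macro pair. Continuing the hypothetical reorder_if_needed sketch from that hunk (pd, eng, s, weight_md, and weight_data assumed in scope), the call site would look like:

// Hypothetical usage, mirroring the branch above: when no cached copy
// exists, reorder per call into the primitive's expected layout.
dnnl::memory weight_mem = reorder_if_needed(
    weight_md, pd.weights_desc(), weight_data, eng, s);
auto* op_weight_data = static_cast<float*>(weight_mem.get_data_handle());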
@@ -406,12 +406,12 @@ class MklQuantizeV2Op : public OpKernel {
     TensorShape output_tf_shape;
     if (src_mkl_shape.IsMklTensor()) {
       output_mkl_shape.SetMklTensor(true);
-      output_mkl_shape.SetMklLayout(&DST_MD);
+      output_mkl_shape.SetMklLayout(&dst_md);
       output_mkl_shape.SetElemType(MklDnnType<T>());
       output_mkl_shape.SetTfLayout(src_mkl_shape.GetDimension(),
                                    src_mkl_shape.GetSizesAsMklDnnDims(),
                                    src_mkl_shape.GetTfDataFormat());
-      output_tf_shape.AddDim(DST_MD.get_size() / sizeof(T));
+      output_tf_shape.AddDim(dst_md.get_size() / sizeof(T));
     } else {
       output_mkl_shape.SetMklTensor(false);
       output_tf_shape = MklDnnDimsToTFShape(output_dims);
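The quantize kernel gets the same macro cleanup: DST_MD becomes the local dst_md descriptor. The TF output is allocated as one flat dimension of dst_md.get_size() / sizeof(T) because a blocked oneDNN layout can be padded past the logical element count. A small standalone check, assuming f32 data and a 16-channel blocked tag:

#include <iostream>
#include "dnnl.hpp"

int main() {
  using namespace dnnl;
  // 3 channels stored in a 16-channel blocked layout: oneDNN pads the
  // channel dimension, so the buffer holds more than n*c*h*w elements.
  memory::desc dst_md({1, 3, 4, 4}, memory::data_type::f32,
                      memory::format_tag::nChw16c);
  size_t flat_elems = dst_md.get_size() / sizeof(float);
  std::cout << "logical elems: " << 1 * 3 * 4 * 4   // 48
            << ", allocated elems: " << flat_elems  // 256 (16 * 4 * 4)
            << "\n";
  return 0;
}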