Merge pull request #29431 from Intel-tensorflow:feature/wenxizhu/redundant-transpose-removal-conv3d
PiperOrigin-RevId: 255053294
This commit is contained in:
commit
78ff99a636
@ -650,6 +650,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
||||
// CheckForMklOp
|
||||
FuseConv2D,
|
||||
CopyAttrsConv});
|
||||
|
||||
// Transpose + Conv3d + Transpose:
|
||||
std::vector<int> transpose_to_ndhwc = {NCDHW::dim::N, NCDHW::dim::D,
|
||||
NCDHW::dim::H, NCDHW::dim::W,
|
||||
NCDHW::dim::C};
|
||||
std::vector<int> transpose_to_ncdhw = {NDHWC::dim::N, NDHWC::dim::C,
|
||||
NDHWC::dim::D, NDHWC::dim::H,
|
||||
NDHWC::dim::W};
|
||||
|
||||
auto CheckForTransposeToNDHWC =
|
||||
std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ndhwc);
|
||||
auto CheckForConv3dOp =
|
||||
std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv3d);
|
||||
auto CheckForTransposeToNCDHW =
|
||||
std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ncdhw);
|
||||
auto FuseConv3D =
|
||||
std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
|
||||
std::placeholders::_2, std::placeholders::_3, "NCDHW");
|
||||
|
||||
finfo_.push_back(
|
||||
{"transpose-elimination for Conv3D",
|
||||
{CheckForTransposeToNDHWC, CheckForConv3dOp, CheckForTransposeToNCDHW},
|
||||
// CheckForMklOp
|
||||
FuseConv3D,
|
||||
CopyAttrsConv});
|
||||
}
|
||||
|
||||
// Standard interface to run pass
|
||||
|
@ -1040,6 +1040,308 @@ TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv2DTranspose_Negative) {
|
||||
"_1:control;Transpose1->Relu;Transpose1:control->DMT/_2:control");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv3DTranspose_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'Input0' op: 'Input'} \
|
||||
node { name: 'Input1' op: 'Input'} \
|
||||
node { name: 'Const0' op: 'Const' \
|
||||
attr { key: 'dtype' value { type: DT_INT32 } } \
|
||||
attr { \
|
||||
key: 'value' \
|
||||
value { \
|
||||
tensor { \
|
||||
dtype: DT_INT32 \
|
||||
tensor_shape { \
|
||||
dim { \
|
||||
size: 5 \
|
||||
} \
|
||||
} \
|
||||
tensor_content: \
|
||||
'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004 \
|
||||
\\000\\000\\000\\001\\000\\000\\000' \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
node { name: 'Const1' op: 'Const' \
|
||||
attr { key: 'dtype' value { type: DT_INT32 } } \
|
||||
attr { \
|
||||
key: 'value' \
|
||||
value { \
|
||||
tensor { \
|
||||
dtype: DT_INT32 \
|
||||
tensor_shape { \
|
||||
dim { \
|
||||
size: 5 \
|
||||
} \
|
||||
} \
|
||||
tensor_content: \
|
||||
'\\000\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\002 \
|
||||
\\000\\000\\000\\003\\000\\000\\000' \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}"
|
||||
"node { \
|
||||
name: 'Transpose0' \
|
||||
op: 'Transpose' \
|
||||
input: 'Input0' \
|
||||
input: 'Const0' \
|
||||
attr { \
|
||||
key: 'T' \
|
||||
value { \
|
||||
type: DT_FLOAT \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'Tperm' \
|
||||
value { \
|
||||
type: DT_INT32 \
|
||||
} \
|
||||
} \
|
||||
}"
|
||||
"node { \
|
||||
name: 'Conv3D' \
|
||||
op: 'Conv3D' \
|
||||
input: 'Transpose0' \
|
||||
input: 'Input1' \
|
||||
attr { \
|
||||
key: 'T' \
|
||||
value { \
|
||||
type: DT_FLOAT \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'data_format' \
|
||||
value { \
|
||||
s: 'NDHWC' \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'dilations' \
|
||||
value { \
|
||||
list { \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'padding' \
|
||||
value { \
|
||||
s: 'SAME' \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'strides' \
|
||||
value { \
|
||||
list { \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'use_cudnn_on_gpu' \
|
||||
value { \
|
||||
b: true \
|
||||
} \
|
||||
} \
|
||||
}"
|
||||
"node { \
|
||||
name: 'Transpose1' \
|
||||
op: 'Transpose' \
|
||||
input: 'Conv3D' \
|
||||
input: 'Const1' \
|
||||
attr { \
|
||||
key: 'T' \
|
||||
value { \
|
||||
type: DT_FLOAT \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'Tperm' \
|
||||
value { \
|
||||
type: DT_INT32 \
|
||||
} \
|
||||
} \
|
||||
}"
|
||||
"node { name: 'Relu' op: 'Relu'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['Transpose1'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);"
|
||||
"DMT/_1(Const);Input0(Input);Input1(Input);"
|
||||
"Relu(_MklRelu)|Conv3D->Relu;Conv3D:2->Relu:1;"
|
||||
"DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;Input0->Conv3D;"
|
||||
"Input0:control->DMT/_0:control;"
|
||||
"Input0:control->DMT/_1:control;Input1->Conv3D:1");
|
||||
}
|
||||
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv3DTranspose_Negative) {
|
||||
InitGraph(
|
||||
"node { name: 'Input0' op: 'Input'} \
|
||||
node { name: 'Input1' op: 'Input'} \
|
||||
node { name: 'Const0' op: 'Const' \
|
||||
attr { \
|
||||
key: 'dtype' \
|
||||
value { \
|
||||
type: DT_INT32 \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'value' \
|
||||
value { \
|
||||
tensor { \
|
||||
dtype: DT_INT32 \
|
||||
tensor_shape { \
|
||||
dim { \
|
||||
size: 5 \
|
||||
} \
|
||||
} \
|
||||
tensor_content: \
|
||||
'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004 \
|
||||
\\000\\000\\000\\001\\000\\000\\000' \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
node { name: 'Const1' op: 'Const' \
|
||||
attr { \
|
||||
key: 'dtype' \
|
||||
value { \
|
||||
type: DT_INT32 \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'value' \
|
||||
value { \
|
||||
tensor { \
|
||||
dtype: DT_INT32 \
|
||||
tensor_shape { \
|
||||
dim { \
|
||||
size: 5 \
|
||||
} \
|
||||
} \
|
||||
tensor_content: \
|
||||
'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004 \
|
||||
\\000\\000\\000\\001\\000\\000\\000' \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
}"
|
||||
"node { \
|
||||
name: 'Transpose0' \
|
||||
op: 'Transpose' \
|
||||
input: 'Input0' \
|
||||
input: 'Const0' \
|
||||
attr { \
|
||||
key: 'T' \
|
||||
value { \
|
||||
type: DT_FLOAT \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'Tperm' \
|
||||
value { \
|
||||
type: DT_INT32 \
|
||||
} \
|
||||
} \
|
||||
}"
|
||||
"node { \
|
||||
name: 'Conv3D' \
|
||||
op: 'Conv3D' \
|
||||
input: 'Transpose0' \
|
||||
input: 'Input1' \
|
||||
attr { \
|
||||
key: 'T' \
|
||||
value { \
|
||||
type: DT_FLOAT \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'data_format' \
|
||||
value { \
|
||||
s: 'NDHWC' \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'dilations' \
|
||||
value { \
|
||||
list { \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'padding' \
|
||||
value { \
|
||||
s: 'SAME' \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'strides' \
|
||||
value { \
|
||||
list { \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
i: 1 \
|
||||
} \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'use_cudnn_on_gpu' \
|
||||
value { \
|
||||
b: true \
|
||||
} \
|
||||
} \
|
||||
}"
|
||||
"node { \
|
||||
name: 'Transpose1' \
|
||||
op: 'Transpose' \
|
||||
input: 'Conv3D' \
|
||||
input: 'Const1' \
|
||||
attr { \
|
||||
key: 'T' \
|
||||
value { \
|
||||
type: DT_FLOAT \
|
||||
} \
|
||||
} \
|
||||
attr { \
|
||||
key: 'Tperm' \
|
||||
value { \
|
||||
type: DT_INT32 \
|
||||
} \
|
||||
} \
|
||||
}"
|
||||
"node { name: 'Relu' op: 'Relu'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['Transpose1'] }");
|
||||
EXPECT_EQ(
|
||||
DoMklLayoutOptimizationPass(),
|
||||
"Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);Input0(Input);Input1(Input);"
|
||||
"Relu(_MklRelu);Transpose0(Transpose);"
|
||||
"Transpose1(Transpose)|Const0->Transpose0:1;"
|
||||
"Const1->Transpose1:1;Conv3D->Transpose1;"
|
||||
"DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;DMT/_2->Relu:1;"
|
||||
"Input0->Transpose0;Input1->Conv3D:1;Transpose0->Conv3D;"
|
||||
"Transpose0:control->DMT/_0:control;Transpose0:control->DMT/_1:control;"
|
||||
"Transpose1->Relu;Transpose1:control->DMT/_2:control");
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Unit tests related to rewriting node to Mkl node
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
Loading…
x
Reference in New Issue
Block a user