Merge pull request #29431 from Intel-tensorflow:feature/wenxizhu/redundant-transpose-removal-conv3d

PiperOrigin-RevId: 255053294
This commit is contained in:
TensorFlower Gardener 2019-06-25 23:43:01 -07:00
commit 78ff99a636
2 changed files with 327 additions and 0 deletions

View File

@ -650,6 +650,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// CheckForMklOp
FuseConv2D,
CopyAttrsConv});
// Transpose + Conv3d + Transpose:
std::vector<int> transpose_to_ndhwc = {NCDHW::dim::N, NCDHW::dim::D,
NCDHW::dim::H, NCDHW::dim::W,
NCDHW::dim::C};
std::vector<int> transpose_to_ncdhw = {NDHWC::dim::N, NDHWC::dim::C,
NDHWC::dim::D, NDHWC::dim::H,
NDHWC::dim::W};
auto CheckForTransposeToNDHWC =
std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ndhwc);
auto CheckForConv3dOp =
std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv3d);
auto CheckForTransposeToNCDHW =
std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ncdhw);
auto FuseConv3D =
std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1,
std::placeholders::_2, std::placeholders::_3, "NCDHW");
finfo_.push_back(
{"transpose-elimination for Conv3D",
{CheckForTransposeToNDHWC, CheckForConv3dOp, CheckForTransposeToNCDHW},
// CheckForMklOp
FuseConv3D,
CopyAttrsConv});
}
// Standard interface to run pass

View File

@ -1040,6 +1040,308 @@ TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv2DTranspose_Negative) {
"_1:control;Transpose1->Relu;Transpose1:control->DMT/_2:control");
}
// Positive case for the Transpose + Conv3D + Transpose elimination rewrite:
// the graph is Transpose(perm0) -> Conv3D(NDHWC) -> Transpose(perm1) where
// perm1 is the inverse of perm0, so the pass should fold both Transpose
// nodes away and rewrite Conv3D to _MklConv3D.
TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv3DTranspose_Positive) {
InitGraph(
// Const0 holds the little-endian int32 permutation {0, 2, 3, 4, 1}
// (NCDHW -> NDHWC); Const1 holds {0, 4, 1, 2, 3} (NDHWC -> NCDHW),
// i.e. the inverse permutation.
// NOTE(review): there is a trailing space before the line-continuation
// backslash inside each tensor_content literal (after '\\004'), which
// becomes part of the string -- confirm the proto text parser tolerates
// it when decoding the 20-byte content.
"node { name: 'Input0' op: 'Input'} \
node { name: 'Input1' op: 'Input'} \
node { name: 'Const0' op: 'Const' \
attr { key: 'dtype' value { type: DT_INT32 } } \
attr { \
key: 'value' \
value { \
tensor { \
dtype: DT_INT32 \
tensor_shape { \
dim { \
size: 5 \
} \
} \
tensor_content: \
'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004 \
\\000\\000\\000\\001\\000\\000\\000' \
} \
} \
} \
} \
node { name: 'Const1' op: 'Const' \
attr { key: 'dtype' value { type: DT_INT32 } } \
attr { \
key: 'value' \
value { \
tensor { \
dtype: DT_INT32 \
tensor_shape { \
dim { \
size: 5 \
} \
} \
tensor_content: \
'\\000\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\002 \
\\000\\000\\000\\003\\000\\000\\000' \
} \
} \
} \
}"
// Transpose0 permutes Input0 by Const0 (into NDHWC order).
"node { \
name: 'Transpose0' \
op: 'Transpose' \
input: 'Input0' \
input: 'Const0' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
// Conv3D consumes the transposed input with data_format NDHWC.
"node { \
name: 'Conv3D' \
op: 'Conv3D' \
input: 'Transpose0' \
input: 'Input1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'data_format' \
value { \
s: 'NDHWC' \
} \
} \
attr { \
key: 'dilations' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'padding' \
value { \
s: 'SAME' \
} \
} \
attr { \
key: 'strides' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'use_cudnn_on_gpu' \
value { \
b: true \
} \
} \
}"
// Transpose1 permutes the Conv3D output by Const1 (back to the original
// layout), making the Transpose pair redundant.
"node { \
name: 'Transpose1' \
op: 'Transpose' \
input: 'Conv3D' \
input: 'Const1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
// Relu is the downstream consumer of the second Transpose.
"node { name: 'Relu' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['Transpose1'] }");
// Expected: both Transpose nodes (and the perm Consts' transpose role) are
// eliminated -- Input0 feeds the rewritten _MklConv3D directly and Relu
// consumes Conv3D's output. The DMT/_* Const nodes (presumably dummy Mkl
// metadata tensors inserted by the rewrite -- verify against the pass) feed
// the extra :2/:3 inputs.
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);"
"DMT/_1(Const);Input0(Input);Input1(Input);"
"Relu(_MklRelu)|Conv3D->Relu;Conv3D:2->Relu:1;"
"DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;Input0->Conv3D;"
"Input0:control->DMT/_0:control;"
"Input0:control->DMT/_1:control;Input1->Conv3D:1");
}
// Negative case: Const1 carries the SAME permutation as Const0 ({0,2,3,4,1})
// rather than its inverse, so the Transpose pair is not redundant and the
// pass must NOT eliminate it. Conv3D is still rewritten to _MklConv3D, but
// both Transpose nodes survive in the optimized graph.
TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv3DTranspose_Negative) {
InitGraph(
// Const0 and Const1 both decode to the little-endian int32 permutation
// {0, 2, 3, 4, 1} -- deliberately identical, not inverse.
// NOTE(review): as in the positive test, a trailing space precedes the
// line-continuation backslash inside tensor_content (after '\\004') --
// confirm the proto text parser tolerates it.
"node { name: 'Input0' op: 'Input'} \
node { name: 'Input1' op: 'Input'} \
node { name: 'Const0' op: 'Const' \
attr { \
key: 'dtype' \
value { \
type: DT_INT32 \
} \
} \
attr { \
key: 'value' \
value { \
tensor { \
dtype: DT_INT32 \
tensor_shape { \
dim { \
size: 5 \
} \
} \
tensor_content: \
'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004 \
\\000\\000\\000\\001\\000\\000\\000' \
} \
} \
} \
} \
node { name: 'Const1' op: 'Const' \
attr { \
key: 'dtype' \
value { \
type: DT_INT32 \
} \
} \
attr { \
key: 'value' \
value { \
tensor { \
dtype: DT_INT32 \
tensor_shape { \
dim { \
size: 5 \
} \
} \
tensor_content: \
'\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004 \
\\000\\000\\000\\001\\000\\000\\000' \
} \
} \
} \
}"
// Transpose0 permutes Input0 by Const0.
"node { \
name: 'Transpose0' \
op: 'Transpose' \
input: 'Input0' \
input: 'Const0' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
// Conv3D consumes the transposed input with data_format NDHWC.
"node { \
name: 'Conv3D' \
op: 'Conv3D' \
input: 'Transpose0' \
input: 'Input1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'data_format' \
value { \
s: 'NDHWC' \
} \
} \
attr { \
key: 'dilations' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'padding' \
value { \
s: 'SAME' \
} \
} \
attr { \
key: 'strides' \
value { \
list { \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
i: 1 \
} \
} \
} \
attr { \
key: 'use_cudnn_on_gpu' \
value { \
b: true \
} \
} \
}"
// Transpose1 applies the SAME (non-inverse) permutation to Conv3D's
// output, so the pair does not cancel out.
"node { \
name: 'Transpose1' \
op: 'Transpose' \
input: 'Conv3D' \
input: 'Const1' \
attr { \
key: 'T' \
value { \
type: DT_FLOAT \
} \
} \
attr { \
key: 'Tperm' \
value { \
type: DT_INT32 \
} \
} \
}"
// Relu is the downstream consumer of the second Transpose.
"node { name: 'Relu' op: 'Relu'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['Transpose1'] }");
// Expected: Transpose0 and Transpose1 are retained (no elimination), while
// Conv3D is still rewritten to _MklConv3D with DMT/_* Const nodes feeding
// its extra Mkl inputs, and Relu is rewritten to _MklRelu.
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);Input0(Input);Input1(Input);"
"Relu(_MklRelu);Transpose0(Transpose);"
"Transpose1(Transpose)|Const0->Transpose0:1;"
"Const1->Transpose1:1;Conv3D->Transpose1;"
"DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;DMT/_2->Relu:1;"
"Input0->Transpose0;Input1->Conv3D:1;Transpose0->Conv3D;"
"Transpose0:control->DMT/_0:control;Transpose0:control->DMT/_1:control;"
"Transpose1->Relu;Transpose1:control->DMT/_2:control");
}
/////////////////////////////////////////////////////////////////////
// Unit tests related to rewriting node to Mkl node
/////////////////////////////////////////////////////////////////////