diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 036f900a580..641b586e7d6 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -650,6 +650,31 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // CheckForMklOp FuseConv2D, CopyAttrsConv}); + + // Transpose + Conv3d + Transpose: + std::vector transpose_to_ndhwc = {NCDHW::dim::N, NCDHW::dim::D, + NCDHW::dim::H, NCDHW::dim::W, + NCDHW::dim::C}; + std::vector transpose_to_ncdhw = {NDHWC::dim::N, NDHWC::dim::C, + NDHWC::dim::D, NDHWC::dim::H, + NDHWC::dim::W}; + + auto CheckForTransposeToNDHWC = + std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ndhwc); + auto CheckForConv3dOp = + std::bind(CheckForMklOp, std::placeholders::_1, csinfo_.conv3d); + auto CheckForTransposeToNCDHW = + std::bind(CheckForTranspose, std::placeholders::_1, transpose_to_ncdhw); + auto FuseConv3D = + std::bind(FuseTransposeMklOpTranspose, std::placeholders::_1, + std::placeholders::_2, std::placeholders::_3, "NCDHW"); + + finfo_.push_back( + {"transpose-elimination for Conv3D", + {CheckForTransposeToNDHWC, CheckForConv3dOp, CheckForTransposeToNCDHW}, + // CheckForMklOp + FuseConv3D, + CopyAttrsConv}); } // Standard interface to run pass diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index 89754d8bb96..b48df7f74bf 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -1040,6 +1040,308 @@ TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv2DTranspose_Negative) { "_1:control;Transpose1->Relu;Transpose1:control->DMT/_2:control"); } +TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv3DTranspose_Positive) { + InitGraph( + "node { name: 'Input0' op: 'Input'} \ + node { name: 'Input1' op: 'Input'} \ + node { name: 'Const0' op: 'Const' \ + attr { key: 'dtype' value { type: DT_INT32 } } \ + attr { \ + key: 'value' \ + value { \ + tensor { \ + dtype: DT_INT32 \ + tensor_shape { \ + dim { \ + size: 5 \ + } \ + } \ + tensor_content: \ + '\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004 \ + \\000\\000\\000\\001\\000\\000\\000' \ + } \ + } \ + } \ + } \ + node { name: 'Const1' op: 'Const' \ + attr { key: 'dtype' value { type: DT_INT32 } } \ + attr { \ + key: 'value' \ + value { \ + tensor { \ + dtype: DT_INT32 \ + tensor_shape { \ + dim { \ + size: 5 \ + } \ + } \ + tensor_content: \ + '\\000\\000\\000\\000\\004\\000\\000\\000\\001\\000\\000\\000\\002 \ + \\000\\000\\000\\003\\000\\000\\000' \ + } \ + } \ + } \ + }" + "node { \ + name: 'Transpose0' \ + op: 'Transpose' \ + input: 'Input0' \ + input: 'Const0' \ + attr { \ + key: 'T' \ + value { \ + type: DT_FLOAT \ + } \ + } \ + attr { \ + key: 'Tperm' \ + value { \ + type: DT_INT32 \ + } \ + } \ + }" + "node { \ + name: 'Conv3D' \ + op: 'Conv3D' \ + input: 'Transpose0' \ + input: 'Input1' \ + attr { \ + key: 'T' \ + value { \ + type: DT_FLOAT \ + } \ + } \ + attr { \ + key: 'data_format' \ + value { \ + s: 'NDHWC' \ + } \ + } \ + attr { \ + key: 'dilations' \ + value { \ + list { \ + i: 1 \ + i: 1 \ + i: 1 \ + i: 1 \ + i: 1 \ + } \ + } \ + } \ + attr { \ + key: 'padding' \ + value { \ + s: 'SAME' \ + } \ + } \ + attr { \ + key: 'strides' \ + value { \ + list { \ + i: 1 \ + i: 1 \ + i: 1 \ + i: 1 \ + i: 1 \ + } \ + } \ + } \ + attr { \ + key: 'use_cudnn_on_gpu' \ + value { \ + b: true \ + } \ + } \ + }" + "node { \ + name: 'Transpose1' \ + op: 'Transpose' \ + input: 'Conv3D' \ + input: 'Const1' \ + attr { \ + key: 'T' \ + value { \ + type: DT_FLOAT \ + } \ + } \ + attr { \ + key: 'Tperm' \ + value { \ + type: DT_INT32 \ + } \ + } \ + }" + "node { name: 'Relu' op: 'Relu'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['Transpose1'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);" + "DMT/_1(Const);Input0(Input);Input1(Input);" + "Relu(_MklRelu)|Conv3D->Relu;Conv3D:2->Relu:1;" + "DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;Input0->Conv3D;" + "Input0:control->DMT/_0:control;" + "Input0:control->DMT/_1:control;Input1->Conv3D:1"); +} + +TEST_F(MklLayoutPassTest, NodeMerge_TransposeConv3DTranspose_Negative) { + InitGraph( + "node { name: 'Input0' op: 'Input'} \ + node { name: 'Input1' op: 'Input'} \ + node { name: 'Const0' op: 'Const' \ + attr { \ + key: 'dtype' \ + value { \ + type: DT_INT32 \ + } \ + } \ + attr { \ + key: 'value' \ + value { \ + tensor { \ + dtype: DT_INT32 \ + tensor_shape { \ + dim { \ + size: 5 \ + } \ + } \ + tensor_content: \ + '\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004 \ + \\000\\000\\000\\001\\000\\000\\000' \ + } \ + } \ + } \ + } \ + node { name: 'Const1' op: 'Const' \ + attr { \ + key: 'dtype' \ + value { \ + type: DT_INT32 \ + } \ + } \ + attr { \ + key: 'value' \ + value { \ + tensor { \ + dtype: DT_INT32 \ + tensor_shape { \ + dim { \ + size: 5 \ + } \ + } \ + tensor_content: \ + '\\000\\000\\000\\000\\002\\000\\000\\000\\003\\000\\000\\000\\004 \ + \\000\\000\\000\\001\\000\\000\\000' \ + } \ + } \ + } \ + }" + "node { \ + name: 'Transpose0' \ + op: 'Transpose' \ + input: 'Input0' \ + input: 'Const0' \ + attr { \ + key: 'T' \ + value { \ + type: DT_FLOAT \ + } \ + } \ + attr { \ + key: 'Tperm' \ + value { \ + type: DT_INT32 \ + } \ + } \ + }" + "node { \ + name: 'Conv3D' \ + op: 'Conv3D' \ + input: 'Transpose0' \ + input: 'Input1' \ + attr { \ + key: 'T' \ + value { \ + type: DT_FLOAT \ + } \ + } \ + attr { \ + key: 'data_format' \ + value { \ + s: 'NDHWC' \ + } \ + } \ + attr { \ + key: 'dilations' \ + value { \ + list { \ + i: 1 \ + i: 1 \ + i: 1 \ + i: 1 \ + i: 1 \ + } \ + } \ + } \ + attr { \ + key: 'padding' \ + value { \ + s: 'SAME' \ + } \ + } \ + attr { \ + key: 'strides' \ + value { \ + list { \ + i: 1 \ + i: 1 \ + i: 1 \ + i: 1 \ + i: 1 \ + } \ + } \ + } \ + attr { \ + key: 'use_cudnn_on_gpu' \ + value { \ + b: true \ + } \ + } \ + }" + "node { \ + name: 'Transpose1' \ + op: 'Transpose' \ + input: 'Conv3D' \ + input: 'Const1' \ + attr { \ + key: 'T' \ + value { \ + type: DT_FLOAT \ + } \ + } \ + attr { \ + key: 'Tperm' \ + value { \ + type: DT_INT32 \ + } \ + } \ + }" + "node { name: 'Relu' op: 'Relu'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['Transpose1'] }"); + EXPECT_EQ( + DoMklLayoutOptimizationPass(), + "Const0(Const);Const1(Const);Conv3D(_MklConv3D);DMT/_0(Const);" + "DMT/_1(Const);DMT/_2(Const);Input0(Input);Input1(Input);" + "Relu(_MklRelu);Transpose0(Transpose);" + "Transpose1(Transpose)|Const0->Transpose0:1;" + "Const1->Transpose1:1;Conv3D->Transpose1;" + "DMT/_0->Conv3D:2;DMT/_1->Conv3D:3;DMT/_2->Relu:1;" + "Input0->Transpose0;Input1->Conv3D:1;Transpose0->Conv3D;" + "Transpose0:control->DMT/_0:control;Transpose0:control->DMT/_1:control;" + "Transpose1->Relu;Transpose1:control->DMT/_2:control"); +} + ///////////////////////////////////////////////////////////////////// // Unit tests related to rewriting node to Mkl node /////////////////////////////////////////////////////////////////////