diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index c27c7aa911b..b1d692da408 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -273,6 +273,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3";
     csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3";
     csinfo_.fused_conv2d = "_FusedConv2D";
+    csinfo_.fused_depthwise_conv2d = "_FusedDepthwiseConv2dNative";
     csinfo_.fused_matmul = "_FusedMatMul";
     csinfo_.identity = "Identity";
     csinfo_.leakyrelu = "LeakyRelu";
@@ -295,6 +296,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     csinfo_.mkl_depthwise_conv2d_grad_filter =
         "_MklDepthwiseConv2dNativeBackpropFilter";
     csinfo_.mkl_fused_conv2d = "_MklFusedConv2D";
+    csinfo_.mkl_fused_depthwise_conv2d = "_MklFusedDepthwiseConv2dNative";
     csinfo_.mkl_fused_matmul = "_MklFusedMatMul";
     csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D";
     csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D";
@@ -479,6 +481,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d,
                       CopyAttrsFusedConv2D, FusedConv2DRewrite,
                       kRewriteForLayoutPropagation});
+    rinfo_.push_back({csinfo_.fused_depthwise_conv2d,
+                      csinfo_.mkl_fused_depthwise_conv2d, CopyAttrsFusedConv2D,
+                      FusedDepthwiseConv2DRewrite,
+                      kRewriteForLayoutPropagation});
     rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul,
                       CopyAttrsAllCheckConstFilter, FusedMatMulRewrite});
@@ -925,6 +931,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string fused_batch_norm_v3;
     string fused_batch_norm_grad_v3;
     string fused_conv2d;
+    string fused_depthwise_conv2d;
     string fused_matmul;
     string identity;
     string leakyrelu;
@@ -945,6 +952,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     string mkl_depthwise_conv2d_grad_input;
     string mkl_depthwise_conv2d_grad_filter;
     string mkl_fused_conv2d;
+    string mkl_fused_depthwise_conv2d;
     string mkl_fused_matmul;
     string mkl_pad_with_conv2d;
     string mkl_pad_with_fused_conv2d;
@@ -1675,6 +1683,25 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
             fused_ops == std::vector<string>{"BiasAdd", "Add", "Relu"});
   }
 
+  static bool FusedDepthwiseConv2DRewrite(const Node* n) {
+    // MKL-DNN currently does not support every fusion that grappler performs
+    // on DepthwiseConv2D (e.g. batch norm). Rewrite
+    // _FusedDepthwiseConv2dNative only if its fusion is one we support.
+    DataType T;
+    if (!TryGetNodeAttr(n->def(), "T", &T) ||
+        !mkl_op_registry::IsMklLayoutDependentOp(
+            csinfo_.mkl_fused_depthwise_conv2d, T)) {
+      return false;
+    }
+
+    std::vector<string> fused_ops;
+    TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops));
+    return (fused_ops == std::vector<string>{"BiasAdd"} ||
+            fused_ops == std::vector<string>{"BiasAdd", "Relu"} ||
+            fused_ops == std::vector<string>{"BiasAdd", "Relu6"} ||
+            fused_ops == std::vector<string>{"BiasAdd", "Elu"});
+  }
+
   // Rewrites input node to a new node specified by its matching rewrite info.
   //
   // Method first searches matching rewrite info for input node and then
@@ -3703,6 +3730,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
         n->type_string() != csinfo_.pad_with_fused_conv2d &&
         n->type_string() != csinfo_.conv2d_grad_filter_with_bias &&
         n->type_string() != csinfo_.fused_conv2d &&
+        n->type_string() != csinfo_.fused_depthwise_conv2d &&
         n->type_string() != csinfo_.fused_matmul &&
         !mkl_op_registry::IsMklOp(
             mkl_op_registry::GetMklOpName(n->type_string()), T)) {
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index 6fe969a99c3..f66b3fca0f6 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -1789,6 +1789,56 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive6);
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
 #undef REGISTER_TEST
 
+// Rewrite test for _FusedDepthwiseConv2dNative Op fusion
+#define REGISTER_TEST(NAME, T, INPUT) \
+  TEST_F(MklLayoutPassTest, NAME##_##T) { \
+    InitGraph( \
+        "node { name: 'A' op: '" #INPUT "'}" \
+        "node { name: 'B' op: '" #INPUT "'}" \
+        "node { name: 'C' op: '" #INPUT "'}" \
+        "node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
+        " attr { key: 'T' value { type: " #T " } }" \
+        " attr { key: 'num_args' value { i: 1 } }" \
+        " attr { key: 'data_format' value { s: 'NCHW' } }" \
+        " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
+        "} }" \
+        " attr { key: 'padding' value { s: 'SAME' } }" \
+        " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
+        "} }" \
+        " attr { key: 'fused_ops' value { list: " FUSED_OPS " } }" \
+        " attr { key: 'epsilon' value { f: 0.001 }}" \
+        " input: ['A', 'B', 'C']}" \
+        "node { name: 'E' op: 'Zeta'" \
+        "attr { key: 'T' value { type: " #T " } }" \
+        " input: ['D', 'C'] }"); \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+              "A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
+              "D(_MklFusedDepthwiseConv2dNative);" \
+              "DMT/_0(Const);DMT/_1(Const);DMT/_2(Const);E(Zeta)|A->D;" \
+              "A:control->DMT/_0:control;A:control->DMT/_1:control;" \
+              "A:control->DMT/_2:control;B->D:1;C->D:2;C->E:1;D->E;" \
+              "DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); \
+  }
+
+// BiasAdd fusion
+#define FUSED_OPS "{s: 'BiasAdd'}"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive1);
+
+// BiasAdd + Relu fusion
+#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu'}"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive2);
+
+// BiasAdd + Relu6 fusion
+#define FUSED_OPS "{s: 'BiasAdd', s: 'Relu6'}"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive3);
+
+// BiasAdd + Elu fusion
+#define FUSED_OPS "{s: 'BiasAdd', s: 'Elu'}"
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Positive4);
+
+#undef FUSED_OPS
+#undef REGISTER_TEST
+
 // Rewrite test for _FusedConv2D Op with unsupported fusion
 #define REGISTER_TEST(NAME, T, INPUT) \
   TEST_F(MklLayoutPassTest, NAME##_##T) { \
     InitGraph( \
@@ -1818,6 +1868,36 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Positive7);
 REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Negative1);
 #undef REGISTER_TEST
 
+// Rewrite test for _FusedDepthwiseConv2dNative with unsupported fusion
+#define REGISTER_TEST(NAME, T, INPUT) \
+  TEST_F(MklLayoutPassTest, NAME##_##T) { \
+    InitGraph( \
+        "node { name: 'A' op: '" #INPUT "'}" \
+        "node { name: 'B' op: '" #INPUT "'}" \
+        "node { name: 'C' op: '" #INPUT "'}" \
+        "node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
+        " attr { key: 'T' value { type: " #T " } }" \
+        " attr { key: 'num_args' value { i: 1 } }" \
+        " attr { key: 'data_format' value { s: 'NCHW' } }" \
+        " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
+        "} }" \
+        " attr { key: 'padding' value { s: 'SAME' } }" \
+        " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
+        "} }" \
+        " attr { key: 'fused_ops' value { list: {s: 'Unsupported'} } }" \
+        " attr { key: 'epsilon' value { f: 0.001 }}" \
+        " input: ['A', 'B', 'C']}" \
+        "node { name: 'E' op: 'Zeta'" \
+        "attr { key: 'T' value { type: " #T " } }" \
+        " input: ['D', 'C'] }"); \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+              "A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
+              "D(_FusedDepthwiseConv2dNative);" \
+              "E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
+  }
+REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedDepthwiseConv2dNative_Negative1);
+#undef REGISTER_TEST
+
 // Rewrite test for _FusedConv2D Op with unsupported type
 #define REGISTER_TEST(NAME, T, INPUT) \
   TEST_F(MklLayoutPassTest, NAME##_##T) { \
     InitGraph( \
@@ -1847,6 +1927,37 @@ REGISTER_TEST_ALL_TYPES(NodeRewrite_FusedConv2D_Negative1);
 REGISTER_TEST(NodeRewrite_FusedConv2D_Negative2, DT_DOUBLE, DoubleInput);
 #undef REGISTER_TEST
 
+// Rewrite test for _FusedDepthwiseConv2dNative Op with unsupported type
+#define REGISTER_TEST(NAME, T, INPUT) \
+  TEST_F(MklLayoutPassTest, NAME##_##T) { \
+    InitGraph( \
+        "node { name: 'A' op: '" #INPUT "'}" \
+        "node { name: 'B' op: '" #INPUT "'}" \
+        "node { name: 'C' op: '" #INPUT "'}" \
+        "node { name: 'D' op: '_FusedDepthwiseConv2dNative'" \
+        " attr { key: 'T' value { type:" #T "} }" \
+        " attr { key: 'num_args' value { i: 1 } }" \
+        " attr { key: 'data_format' value { s: 'NCHW' } }" \
+        " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} " \
+        "} }" \
+        " attr { key: 'padding' value { s: 'SAME' } }" \
+        " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} " \
+        "} }" \
+        " attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" \
+        " attr { key: 'epsilon' value { f: 0.001 }}" \
+        " input: ['A', 'B', 'C']}" \
+        "node { name: 'E' op: 'Zeta'" \
+        "attr { key: 'T' value { type: " #T "} }" \
+        " input: ['D', 'C'] }"); \
+    EXPECT_EQ(DoMklLayoutOptimizationPass(), \
+              "A(" #INPUT ");B(" #INPUT ");C(" #INPUT ");" \
+              "D(_FusedDepthwiseConv2dNative);" \
+              "E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); \
+}
+REGISTER_TEST(NodeRewrite_FusedDepthwiseConv2dNative_Negative2,
+              DT_DOUBLE, DoubleInput);
+#undef REGISTER_TEST
+
 // Test set: _FusedMatMul -> MklFusedMatMul rewrite tests
 #define REGISTER_TEST(NAME, T, INPUT) \
   TEST_F(MklLayoutPassTest, NAME##_##T) { \
@@ -4240,6 +4351,33 @@ TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Positive) {
                                                    "_MklFusedConv2D"));
 }
 
+// _FusedDepthwiseConv2dNative + BiasAdd fusion where filter is a constant.
+TEST_F(MklLayoutPassTest,
+       FusedDepthwiseConv2dNativeWithBias_FilterCaching_Positive) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Const'"  // Filter
+      " attr { key: 'dtype' value { type: DT_FLOAT } }"
+      " attr { key: 'value' value { "
+      "    tensor { dtype: DT_FLOAT tensor_shape { dim { size: 1 } } "
+      "    int_val: 0 } } } }"
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: '_FusedDepthwiseConv2dNative'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'num_args' value { i: 1 } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
+      " attr { key: 'epsilon' value { f: 0.001 }}"
+      " input: ['A', 'B', 'C']}"
+      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'C'] }");
+  EXPECT_TRUE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
+      "is_filter_const", "_MklFusedDepthwiseConv2dNative"));
+}
+
 // _FusedConv2D + BiasAdd fusion where filter is NOT a constant.
 TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Negative) {
   InitGraph(
@@ -4262,6 +4400,28 @@ TEST_F(MklLayoutPassTest, FusedConv2DWithBias_FilterCaching_Negative) {
                                                     "_MklFusedConv2D"));
 }
 
+// _FusedDepthwiseConv2dNative + BiasAdd fusion where filter is NOT a constant.
+TEST_F(MklLayoutPassTest,
+       FusedDepthwiseConv2dNativeWithBias_FilterCaching_Negative) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"  // Filter
+      "node { name: 'C' op: 'Input'}"
+      "node { name: 'D' op: '_FusedDepthwiseConv2dNative'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " attr { key: 'num_args' value { i: 1 } }"
+      " attr { key: 'data_format' value { s: 'NCHW' } }"
+      " attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'padding' value { s: 'SAME' } }"
+      " attr { key: 'dilations' value { list: {i: 1, i:1, i:1, i:1} } }"
+      " attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }"
+      " attr { key: 'epsilon' value { f: 0.001 }}"
+      " input: ['A', 'B', 'C']}"
+      "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['D', 'C'] }");
+  EXPECT_FALSE(DoMklLayoutOptimizationPassGetAttrVal<bool>(
+      "is_filter_const", "_MklFusedDepthwiseConv2dNative"));
+}
+
 // Depthwise Conv2D op where filter is a constant.
 TEST_F(MklLayoutPassTest, DepthwiseConv2dNative_FilterCaching_Positive) {
   InitGraph(
diff --git a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
index 120f252c933..7a6b4907bf4 100644
--- a/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
+++ b/tensorflow/core/grappler/optimizers/mkl_remapper_test.cc
@@ -206,6 +206,94 @@ CREATE_CONV2DFUSION_ADD_BCAST_TEST(AddV2);
 #undef CREATE_CONV2DFUSION_ADD_BCAST_TEST
 #undef CREATE_CONV2DFUSION_TEST
 
+TEST_F(MklRemapperTest, FuseDepthwiseConv2DWithBiasAndActivation) {
+  using ::tensorflow::ops::Placeholder;
+
+  for (const string& activation : {"Relu", "Relu6", "Elu", "None"}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+    auto input_shape = Placeholder::Shape({8, 32, 32, 3});
+    auto filter_shape = Placeholder::Shape({1, 1, 3, 1});
+    auto bias_shape = Placeholder::Shape({3});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+    auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
+    auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
+
+    std::vector<int> strides = {1, 1, 1, 1};
+    auto conv = ops::DepthwiseConv2dNative(s.WithOpName("depthwise_conv"),
+                                           input, filter, strides, "SAME");
+    auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias);
+
+    ops::Identity fetch = [&]() -> ops::Identity {
+      auto activate = s.WithOpName("activation");
+      auto fetch = s.WithOpName("fetch");
+
+      if (activation == "Relu") {
+        return ops::Identity(fetch, ops::Relu(activate, bias_add));
+      } else if (activation == "Relu6") {
+        return ops::Identity(fetch, ops::Relu6(activate, bias_add));
+      } else if (activation == "Elu") {
+        return ops::Identity(fetch, ops::Elu(activate, bias_add));
+      }
+
+      DCHECK(activation == "None");
+      return ops::Identity(fetch, bias_add);
+    }();
+
+    auto input_t = GenerateRandomTensor<DT_FLOAT>({8, 32, 32, 3});
+    auto filter_t = GenerateRandomTensor<DT_FLOAT>({1, 1, 3, 1});
+    auto bias_t = GenerateRandomTensor<DT_FLOAT>({3});
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}};
+    TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on CPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:CPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::ON);
+    GraphDef output;
+    TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() != "bias_add" && node.name() != "activation") continue;
+
+      EXPECT_EQ(node.op(), "_FusedDepthwiseConv2dNative");
+      ASSERT_EQ(node.input_size(), 3);
+      EXPECT_EQ(node.input(0), "input");
+      EXPECT_EQ(node.input(1), "filter");
+
+      EXPECT_EQ(node.attr().at("num_args").i(), 1);
+      EXPECT_EQ(node.input(2), "bias");
+
+      const auto fused_ops = node.attr().at("fused_ops").list().s();
+      if (node.name() == "bias_add") {
+        ASSERT_EQ(fused_ops.size(), 1);
+        EXPECT_EQ(fused_ops[0], "BiasAdd");
+        found++;
+      }
+      if (node.name() == "activation") {
+        ASSERT_EQ(fused_ops.size(), 2);
+        EXPECT_EQ(fused_ops[0], "BiasAdd");
+        EXPECT_EQ(fused_ops[1], activation);
+        found++;
+      }
+    }
+    EXPECT_EQ(found, 1);
+
+    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+    ASSERT_EQ(tensors_expected.size(), 1);
+    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+    ASSERT_EQ(tensors.size(), 1);
+    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
+  }
+}
+
 }  // namespace grappler
 }  // namespace tensorflow
 #endif  // INTEL_MKL
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 4386935ee3f..c9cce4daca1 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -57,6 +57,7 @@ namespace {
 
 constexpr char kFusedConv2D[] = "_FusedConv2D";
 constexpr char kFusedMatMul[] = "_FusedMatMul";
+constexpr char kFusedDepthwiseConv2dNative[] = "_FusedDepthwiseConv2dNative";
 constexpr char kFusedBatchNormEx[] = "_FusedBatchNormEx";
 
 constexpr char kDataFormat[] = "data_format";
@@ -279,12 +280,24 @@ bool IsCpuCompatibleMatMul(const NodeDef* matmul) {
   return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul);
 }
 
+bool IsCpuCompatibleDepthwiseConv2dNative(const NodeDef* dw_conv2d) {
+  DCHECK(IsDepthwiseConv2dNative(*dw_conv2d))
+      << "Expected DepthwiseConv2dNative op";
+  return NodeIsOnCpu(dw_conv2d) && IsCpuCompatibleDataType(dw_conv2d);
+}
+
 // Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU.
 template <typename Pattern>
 bool IsCpuCompatible(const RemapperContext& ctx, const Pattern& matched) {
   const NodeDef& node = ctx.graph_view.graph()->node(matched.contraction);
   if (IsConv2D(node)) {
     return IsCpuCompatibleConv2D(&node);
+  } else if (IsDepthwiseConv2dNative(node)) {
+#ifdef INTEL_MKL
+    return IsCpuCompatibleDepthwiseConv2dNative(&node);
+#else
+    return false;
+#endif  // INTEL_MKL
   } else if (IsMatMul(node)) {
     return IsCpuCompatibleMatMul(&node);
   } else {
@@ -381,11 +394,12 @@ bool FindContractionWithBias(const RemapperContext& ctx, int node_index,
   const auto* contraction_node_view = regular_fanin_0.node_view();
   const auto* contraction_node_def = contraction_node_view->node();
 
-  bool is_conv2d_or_matmul =
-      IsConv2D(*contraction_node_def) || IsMatMul(*contraction_node_def);
+  // Conv2D, MatMul or DepthwiseConv2D
+  bool is_contraction = IsConv2D(*contraction_node_def) ||
+                        IsMatMul(*contraction_node_def) ||
+                        IsDepthwiseConv2dNative(*contraction_node_def);
 
-  if (!is_conv2d_or_matmul ||
-      !HaveSameDataType(node_def, contraction_node_def) ||
+  if (!is_contraction || !HaveSameDataType(node_def, contraction_node_def) ||
       HasControlFaninOrFanout(*contraction_node_view) ||
       !HasAtMostOneFanoutAtPort0(*contraction_node_view) ||
       IsInPreserveSet(ctx, contraction_node_def))
@@ -902,6 +916,21 @@ void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) {
   (*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
 }
 
+void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
+                                         NodeDef* fused_dw_conv2d) {
+  DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
+      << "Input node must be a DepthwiseConv2dNative";
+
+  auto* attr = fused_dw_conv2d->mutable_attr();
+  auto& src_attr = dw_conv2d.attr();
+
+  (*attr)["T"] = src_attr.at("T");
+  (*attr)["strides"] = src_attr.at("strides");
+  (*attr)["padding"] = src_attr.at("padding");
+  (*attr)["dilations"] = src_attr.at("dilations");
+  (*attr)["data_format"] = src_attr.at("data_format");
+}
+
 void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
                                   NodeDef* fused_batch_norm_ex) {
   DCHECK(IsFusedBatchNorm(fused_batch_norm))
@@ -966,6 +995,9 @@ Status AddFusedContractionNode(RemapperContext* ctx,
   if (IsConv2D(contraction)) {
     fused_op.set_op(kFusedConv2D);
     CopyConv2DAttributes(contraction, &fused_op);
+  } else if (IsDepthwiseConv2dNative(contraction)) {
+    fused_op.set_op(kFusedDepthwiseConv2dNative);
+    CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
   } else if (IsMatMul(contraction)) {
     fused_op.set_op(kFusedMatMul);
     CopyMatMulAttributes(contraction, &fused_op);
@@ -1010,6 +1042,9 @@ Status AddFusedContractionNode(
   if (IsConv2D(contraction)) {
     fused_op.set_op(kFusedConv2D);
     CopyConv2DAttributes(contraction, &fused_op);
+  } else if (IsDepthwiseConv2dNative(contraction)) {
+    fused_op.set_op(kFusedDepthwiseConv2dNative);
+    CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
   } else if (IsMatMul(contraction)) {
     fused_op.set_op(kFusedMatMul);
     CopyMatMulAttributes(contraction, &fused_op);
@@ -1660,7 +1695,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 #endif  // !INTEL_MKL
 
-    // Remap {Conv2D,MatMul}+BiasAdd into the _Fused{Conv2D,MatMul}
+    // Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd into the
+    // _Fused{Conv2D,DepthwiseConv2dNative,MatMul}
     ContractionWithBiasAdd contract_with_bias;
     if (allow_non_differentiable_rewrites &&
         FindContractionWithBias(ctx, i, &contract_with_bias)) {
@@ -1669,7 +1705,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
       continue;
     }
 
-    // Remap {Conv2D,MatMul}+BiasAdd+Activation into the _Fused{Conv2D,MatMul}.
+    // Remap {Conv2D,DepthwiseConv2D,MatMul}+BiasAdd+Activation into the
+    // _Fused{Conv2D,DepthwiseConv2dNative,MatMul}.
     ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
     if (allow_non_differentiable_rewrites &&
         FindContractionWithBiasAndActivation(
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 80a53ad277e..52bb1e404e1 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -1363,6 +1363,58 @@ class MklFusedConvOp
   virtual ~MklFusedConvOp() {}
 };
 
+template
+class MklFusedDepthwiseConvOp
+    : public MklConvOp {
+ public:
+  explicit MklFusedDepthwiseConvOp(OpKernelConstruction* context)
+      : MklConvOp(context) {
+    // Since we came here through the registration of
+    // _MklFusedDepthwiseConv2dNative, get all
+    // information from 'fused_ops' and 'num_args'
+    std::vector<string> fused_ops;
+    OP_REQUIRES_OK(context, context->GetAttr("fused_ops", &fused_ops));
+
+    int num_args;
+    OP_REQUIRES_OK(context, context->GetAttr("num_args", &num_args));
+    OP_REQUIRES(context, !fused_ops.empty(),
+                errors::InvalidArgument(
+                    "Fused DepthwiseConv2D must have at least one fused op."));
+
+    if (fused_ops == std::vector<string>{"BiasAdd"}) {
+      this->set_fuse_biasadd(true);
+    } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu"}) {
+      this->set_fuse_biasadd(true);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_relu);
+    } else if (fused_ops == std::vector<string>{"BiasAdd", "Relu6"}) {
+      this->set_fuse_biasadd(true);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0);
+    } else if (fused_ops == std::vector<string>{"BiasAdd", "Elu"}) {
+      this->set_fuse_biasadd(true);
+      this->set_fuse_activation(true, ALGORITHM::eltwise_elu, 1.0);
+    } else {
+      OP_REQUIRES(context, false,
+                  errors::Unimplemented("Fusion is not implemented: [",
+                                        absl::StrJoin(fused_ops, ","), "]"));
+    }
+
+    OP_REQUIRES(
+        context, num_args == 1,
+        errors::InvalidArgument(
+            "Fused DepthwiseConv2D must have one extra argument: bias."));
+
+    if (pad_enabled) {
+      this->set_fuse_pad(true);
+    }
+  }
+
+  virtual ~MklFusedDepthwiseConvOp() {}
+};
+
 // We create new class for each version of Quantized Convolution and inherit
 // from the FP32 version of the base class
 template ("out_type"),
                         NoOp);
 
+REGISTER_KERNEL_BUILDER(Name("_FusedDepthwiseConv2dNative")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint("T"),
+                        NoOp);
+
 // Register templatized MKL kernels for non-fused and fused-versions of
 // QuantizedDepthwiseConv2D.
 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2D")
@@ -2306,6 +2363,14 @@ REGISTER_KERNEL_BUILDER(
     MklQuantizedConv2DReluOp);
 
+REGISTER_KERNEL_BUILDER(
+    Name("_MklFusedDepthwiseConv2dNative")
+        .Device(DEVICE_CPU)
+        .TypeConstraint("T")
+        .Label(mkl_op_registry::kMklLayoutDependentOpLabel),
+    MklFusedDepthwiseConvOp);
+
 // Register 2D operations
 #define REGISTER_MKL_CPU_2D(T) \
   REGISTER_KERNEL_BUILDER( \
diff --git a/tensorflow/core/kernels/mkl_fused_ops_test.cc b/tensorflow/core/kernels/mkl_fused_ops_test.cc
index ff4f678e476..edd1201a09c 100644
--- a/tensorflow/core/kernels/mkl_fused_ops_test.cc
+++ b/tensorflow/core/kernels/mkl_fused_ops_test.cc
@@ -134,6 +134,7 @@ class CommonTestUtilities : public OpsTestBase {
   static void VerifyFusedTensorsClose(int depth, int image_width,
                                       int image_height, int image_batch_count,
                                       int filter_size, int filter_count,
+                                      int bias_size,
                                       const std::vector<string>& fused_ops,
                                       const FusedGraphRunner& run_default,
                                       const FusedGraphRunner& run_fused) {
@@ -145,7 +146,6 @@ class CommonTestUtilities : public OpsTestBase {
     Tensor filter(dtype, {filter_size, filter_size, depth, filter_count});
     filter.flat<T>() = filter.flat<T>().template setRandom();
 
-    const int bias_size = filter_count;
     Tensor bias(dtype, {bias_size});
     bias.flat<T>() = bias.flat<T>().template setRandom();
@@ -321,9 +321,10 @@ class MklFusedConv2DOpTest : public OpsTestBase {
               out);
         };
 
+    const int bias_size = filter_count;
     CommonTestUtilities<T>::VerifyFusedTensorsClose(
         depth, image_width, image_height, image_batch_count, filter_size,
-        filter_count, fused_ops, run_default, run_fused);
+        filter_count, bias_size, fused_ops, run_default, run_fused);
   }
 };
 
@@ -449,6 +450,223 @@ REGISTER_TYPED_TEST_CASE_P(
 using MklFusedBiasAddDataTypes = ::testing::Types<float>;
 INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedConv2DWithBiasOpTest,
                               MklFusedBiasAddDataTypes);
+
+// Testing MKL's fused depthwise convolution ops
+template <typename T>
+class MklFusedDepthwiseConv2DOpTest : public OpsTestBase {
+ protected:
+  static constexpr int kDepth = 3;
+  static constexpr int kImageWidth = 32;
+  static constexpr int kImageHeight = 32;
+  static constexpr int kImageBatchCount = 8;
+
+  void RunDepthwiseConv2DUnfused(const Tensor& input_data,
+                                 const Tensor& filter_data,
+                                 const Tensor& bias_data,
+                                 const std::vector<string>& fused_ops,
+                                 Tensor* output, int stride = 1) {
+    auto root = tensorflow::Scope::NewRootScope();
+    auto input_data_op =
+        ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
+    Output next_op = ops::DepthwiseConv2dNative(
+        root.WithOpName("depthwise_conv"), input_data_op,
+        ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data)),
+        {1, stride, stride, 1}, "SAME");
+
+    string last_op = "";
+    if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") !=
+        fused_ops.end()) {
+      last_op = "with_bias";
+      next_op = ops::BiasAdd(
+          root.WithOpName(last_op), next_op,
+          ops::Const(root.WithOpName("bias"), Input::Initializer(bias_data)));
+    }
+
+    if (std::find(fused_ops.begin(), fused_ops.end(), "Relu") !=
+        fused_ops.end()) {
+      last_op = "with_relu";
+      next_op = ops::Relu(root.WithOpName(last_op), next_op);
+    }
+
+    if (std::find(fused_ops.begin(), fused_ops.end(), "Relu6") !=
+        fused_ops.end()) {
+      last_op = "with_relu6";
+      next_op = ops::Relu6(root.WithOpName(last_op), next_op);
+    }
+
+    if (std::find(fused_ops.begin(), fused_ops.end(), "Elu") !=
+        fused_ops.end()) {
+      last_op = "with_elu";
+      next_op = ops::Elu(root.WithOpName(last_op), next_op);
+    }
+
+    CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
+  }
+
+  void RunMklFusedDepthwiseConv2DOp(const Tensor& image, const Tensor& filter,
+                                    const std::vector<Tensor>& args,
+                                    const std::vector<string>& fused_ops,
+                                    Tensor* output, int stride = 1) {
+    DataType dtype = DataTypeToEnum<T>::v();
+    int num_args = static_cast<int>(args.size());
+
+    TF_EXPECT_OK(NodeDefBuilder("fused_depthwise_conv_op",
+                                "_MklFusedDepthwiseConv2dNative")
+                     .Input(FakeInput(dtype))
+                     .Input(FakeInput(dtype))
+                     .Input(FakeInput(num_args, dtype))
+                     .Input(FakeInput(DT_UINT8))
+                     .Input(FakeInput(DT_UINT8))
+                     .Input(FakeInput(num_args, DT_UINT8))
+                     .Attr("T", dtype)
+                     .Attr("num_args", num_args)
+                     .Attr("strides", {1, stride, stride, 1})
+                     .Attr("padding", "SAME")
+                     .Attr("fused_ops", fused_ops)
+                     .Attr("_kernel", "MklLayoutDependentOp")
+                     .Finalize(node_def()));
+
+    TF_EXPECT_OK(InitOp());
+
+    AddInputFromArray<T>(image.shape(), image.flat<T>());
+    AddInputFromArray<T>(filter.shape(), filter.flat<T>());
+    for (const Tensor& arg : args)
+      AddInputFromArray<T>(arg.shape(), arg.flat<T>());
+    AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+    AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+    for (const Tensor& arg : args)
+      AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+    TF_ASSERT_OK(RunOpKernel());
+
+    // Compare output to expected results.
+    const Tensor& output_tensor = *GetOutput(0);
+    // Index 2 will need to be changed if the number of outputs produced
+    // by MklDepthwiseConv2D changes.
+    const Tensor& output_meta_tensor = *GetOutput(2);
+    CommonTestUtilities<T> test_util;
+    test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
+                                output);
+  }
+
+  // Verifies that computing the unfused ops in a graph is identical to
+  // FusedDepthwiseConv2D.
+  void VerifyFusedDepthwiseConv2D(int filter_size, int filter_count,
+                                  int bias_size,
+                                  const std::vector<string>& fused_ops,
+                                  int depth = kDepth,
+                                  int image_width = kImageWidth,
+                                  int image_height = kImageHeight,
+                                  int image_batch_count = kImageBatchCount) {
+    const FusedGraphRunner run_default =
+        [this](const Tensor& input_data, const Tensor& filter_data,
+               const Tensor& bias_data, const std::vector<string>& fused_ops,
+               Tensor* out) {
+          RunDepthwiseConv2DUnfused(input_data, filter_data, bias_data,
+                                    fused_ops, out);
+        };
+
+    const FusedGraphRunner run_fused =
+        [this](const Tensor& input_data, const Tensor& filter_data,
+               const Tensor& bias_data, const std::vector<string>& fused_ops,
+               Tensor* out) {
+          std::vector<Tensor> fused_input = {bias_data};
+          RunMklFusedDepthwiseConv2DOp(input_data, filter_data, fused_input,
+                                       fused_ops, out);
+        };
+
+    CommonTestUtilities<T>::VerifyFusedTensorsClose(
+        depth, image_width, image_height, image_batch_count, filter_size,
+        filter_count, bias_size, fused_ops, run_default, run_fused);
+  }
+};
+
+template <typename T>
+class MklFusedDepthwiseConv2DWithBiasOpTest
+    : public MklFusedDepthwiseConv2DOpTest<T> {};
+
+TYPED_TEST_SUITE_P(MklFusedDepthwiseConv2DWithBiasOpTest);
+
+// -------------------------------------------------------------------------- //
+// DepthwiseConv2D + BiasAdd + {Activation}                                   //
+// -------------------------------------------------------------------------- //
+
+TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolution) {
+  const int kFilterSize = 1;
+  const int kFilterCount = 1;
+  const int kBiasSize = 3;
+  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
+                                   {"BiasAdd"});
+}
+
+TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolution) {
+  const int kFilterSize = 3;
+  const int kFilterCount = 1;
+  const int kBiasSize = 3;
+  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
+                                   {"BiasAdd"});
+}
+
+TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
+             OneByOneConvolutionAndRelu) {
+  const int kFilterSize = 1;
+  const int kFilterCount = 1;
+  const int kBiasSize = 3;
+  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
+                                   {"BiasAdd", "Relu"});
+}
+
+TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndRelu) {
+  const int kFilterSize = 3;
+  const int kFilterCount = 1;
+  const int kBiasSize = 3;
+  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
+                                   {"BiasAdd", "Relu"});
+}
+
+TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
+             OneByOneConvolutionAndRelu6) {
+  const int kFilterSize = 1;
+  const int kFilterCount = 1;
+  const int kBiasSize = 3;
+  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
+                                   {"BiasAdd", "Relu6"});
+}
+
+TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest,
+             SpatialConvolutionAndRelu6) {
+  const int kFilterSize = 3;
+  const int kFilterCount = 1;
+  const int kBiasSize = 3;
+  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
+                                   {"BiasAdd", "Relu6"});
+}
+
+TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolutionAndElu) {
+  const int kFilterSize = 1;
+  const int kFilterCount = 1;
+  const int kBiasSize = 3;
+  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
+                                   {"BiasAdd", "Elu"});
+}
+
+TYPED_TEST_P(MklFusedDepthwiseConv2DWithBiasOpTest, SpatialConvolutionAndElu) {
+  const int kFilterSize = 3;
+  const int kFilterCount = 1;
+  const int kBiasSize = 3;
+  this->VerifyFusedDepthwiseConv2D(kFilterSize, kFilterCount, kBiasSize,
+                                   {"BiasAdd", "Elu"});
+}
+
+REGISTER_TYPED_TEST_SUITE_P(
+    MklFusedDepthwiseConv2DWithBiasOpTest, OneByOneConvolution,
+    SpatialConvolution, OneByOneConvolutionAndRelu, SpatialConvolutionAndRelu,
+    OneByOneConvolutionAndRelu6, SpatialConvolutionAndRelu6,
+    OneByOneConvolutionAndElu, SpatialConvolutionAndElu);
+
+using MklFusedBiasAddDataTypes = ::testing::Types<float>;
+INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedDepthwiseConv2DWithBiasOpTest,
+                               MklFusedBiasAddDataTypes);
+
 // Testing fusion of pad and convolution
 class FusedPadConvOpTest : public OpsTestBase {
diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc
index 47b3745573b..a625fb64ed3 100644
--- a/tensorflow/core/ops/mkl_nn_ops.cc
+++ b/tensorflow/core/ops/mkl_nn_ops.cc
@@ -61,6 +61,30 @@ REGISTER_OP("_MklFusedConv2D")
 is expected to create these operators.
 )doc");
 
+REGISTER_OP("_MklFusedDepthwiseConv2dNative")
+    .Input("input: T")
+    .Input("filter: T")
+    .Input("args: num_args * T")
+    .Input("mkl_input: uint8")
+    .Input("mkl_filter: uint8")
+    .Input("mkl_args: num_args * uint8")
+    .Output("output: T")
+    .Output("filter_output: T")
+    .Output("mkl_output: uint8")
+    .Output("mkl_filter_output: uint8")
+    .Attr("T: {bfloat16, float}")
+    .Attr("num_args: int >= 0")
+    .Attr("strides: list(int)")
+    .Attr("is_filter_const: bool = false")
+    .Attr(GetPaddingAttrString())
+    .Attr(GetConvnetDataFormatAttrString())
+    .Attr("dilations: list(int) = [1, 1, 1, 1]")
+    .Attr("fused_ops: list(string) = []")
+    // Attributes for the FusedBatchNorm ----------------------------------- //
+    .Attr("epsilon: float = 0.0001")
+    // ---------------------------------------------------------------------- //
+    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
+
 REGISTER_OP("_MklFusedMatMul")
     .Input("a: T")
     .Input("b: T")
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 83260bfedc9..9200547cf45 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -596,6 +596,23 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
       return Status::OK();
     });
 
+REGISTER_OP("_FusedDepthwiseConv2dNative")
+    .Input("input: T")
+    .Input("filter: T")
+    .Input("args: num_args * T")
+    .Output("output: T")
+    .Attr("T: {half, bfloat16, float, double}")
+    .Attr("num_args: int >= 0")
+    .Attr("strides: list(int)")
+    .Attr(GetPaddingAttrString())
+    .Attr(GetConvnetDataFormatAttrString())
+    .Attr("dilations: list(int) = [1, 1, 1, 1]")
+    .Attr("fused_ops: list(string) = []")
+    // Attributes for the FusedBatchNorm ----------------------------------- //
+    .Attr("epsilon: float = 0.0001")
+    // ---------------------------------------------------------------------- //
+    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
+
 // --------------------------------------------------------------------------
 
 REGISTER_OP("Conv3D")
    .Input("input: T")