diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 02fec134915..12542947b01 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -211,6 +211,8 @@ bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) { return false; } +bool IsElu(const NodeDef& node) { return node.op() == "Elu"; } + bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; } bool IsEnter(const NodeDef& node) { @@ -412,6 +414,8 @@ bool IsReduction(const NodeDef& node) { bool IsRelu(const NodeDef& node) { return node.op() == "Relu"; } +bool IsRelu6(const NodeDef& node) { return node.op() == "Relu6"; } + bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; } bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index fec635bccf2..643ecd4e6eb 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -66,6 +66,7 @@ bool IsDequeueOp(const NodeDef& node); bool IsDiv(const NodeDef& node); bool IsDivNoNan(const NodeDef& node); bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing); +bool IsElu(const NodeDef& node); bool IsEluGrad(const NodeDef& node); bool IsEnter(const NodeDef& node); bool IsEqual(const NodeDef& node); @@ -134,6 +135,7 @@ bool IsReciprocalGrad(const NodeDef& node); bool IsRecv(const NodeDef& node); bool IsReduction(const NodeDef& node); bool IsRelu(const NodeDef& node); +bool IsRelu6(const NodeDef& node); bool IsRelu6Grad(const NodeDef& node); bool IsReluGrad(const NodeDef& node); bool IsReshape(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 193772fcda2..50c93496b5f 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -78,16 +78,16 @@ struct 
Conv2DWithBiasAdd { const NodeDef* bias_add = nullptr; }; -// Conv2D node followed by a BiasAdd and Relu. -struct Conv2DWithBiasAddAndRelu { - Conv2DWithBiasAddAndRelu() = default; - Conv2DWithBiasAddAndRelu(const NodeDef* conv2d, const NodeDef* bias_add, - const NodeDef* relu) - : conv2d(conv2d), bias_add(bias_add), relu(relu) {} +// Conv2D node followed by a BiasAdd and Activation. +struct Conv2DWithBiasAddAndActivation { + Conv2DWithBiasAddAndActivation() = default; + Conv2DWithBiasAddAndActivation(const NodeDef* conv2d, const NodeDef* bias_add, + const NodeDef* activation) + : conv2d(conv2d), bias_add(bias_add), activation(activation) {} const NodeDef* conv2d = nullptr; const NodeDef* bias_add = nullptr; - const NodeDef* relu = nullptr; + const NodeDef* activation = nullptr; }; // Conv2D node followed by a Squeeze and BiasAdd. @@ -114,20 +114,21 @@ struct Conv2DWithBatchNorm { float epsilon = 0.0; }; -// Conv2D node followed by a FusedBatchNorm and Relu. -struct Conv2DWithBatchNormAndRelu { - Conv2DWithBatchNormAndRelu() = default; - Conv2DWithBatchNormAndRelu(const NodeDef* conv2d, - const NodeDef* fused_batch_norm, - const NodeDef* relu, float epsilon = 0.0) +// Conv2D node followed by a FusedBatchNorm and Activation. +struct Conv2DWithBatchNormAndActivation { + Conv2DWithBatchNormAndActivation() = default; + Conv2DWithBatchNormAndActivation(const NodeDef* conv2d, + const NodeDef* fused_batch_norm, + const NodeDef* activation, + float epsilon = 0.0) : conv2d(conv2d), fused_batch_norm(fused_batch_norm), - relu(relu), + activation(activation), epsilon(epsilon) {} const NodeDef* conv2d = nullptr; const NodeDef* fused_batch_norm = nullptr; - const NodeDef* relu = nullptr; + const NodeDef* activation = nullptr; float epsilon = 0.0; }; @@ -194,7 +195,7 @@ bool IsCpuCompatible(const Pattern& matched) { // Checks if we can rewrite a pattern to the `_FusedConv2D` on GPU device. 
bool IsGpuCompatible(const RemapperContext& ctx, - const Conv2DWithBiasAddAndRelu& matched) { + const Conv2DWithBiasAddAndActivation& matched) { const std::vector& input_props = ctx.graph_properties.GetInputProperties(matched.conv2d->name()); const TensorShapeProto& filter_shape = @@ -226,6 +227,10 @@ bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) { return IsCpuCompatible(matched) || IsGpuCompatible(ctx, matched); } +bool IsSupportedActivation(const NodeDef& node) { + return IsRelu(node) || IsRelu6(node) || IsElu(node); +} + bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* bias_add, Conv2DWithBiasAdd* matched, bool check_device_compatible = true) { @@ -259,31 +264,34 @@ bool FindConv2DWithBias(const RemapperContext& ctx, const NodeDef* bias_add, return true; } -bool FindConv2DWithBiasAndRelu(const RemapperContext& ctx, const NodeDef* relu, - Conv2DWithBiasAddAndRelu* matched) { +bool FindConv2DWithBiasAndActivation(const RemapperContext& ctx, + const NodeDef* activation, + Conv2DWithBiasAddAndActivation* matched) { if (!EigenSupportsContractionOutputKernel()) return false; - // Root of the pattern must be a Relu. - if (!relu || !IsRelu(*relu) || HasControlFaninOrFanout(ctx.graph_view, relu)) + // Root of the pattern must be an activation node. + if (!activation || !IsSupportedActivation(*activation) || + HasControlFaninOrFanout(ctx.graph_view, activation)) return false; - // And input to Relu must match Conv2DWithBiasAdd pattern. - const auto input_port = GraphView::InputPort(relu, 0); + // And input to the activation node must match Conv2DWithBiasAdd pattern. 
+ const auto input_port = GraphView::InputPort(activation, 0); const auto bias_add = ctx.graph_view.GetRegularFanin(input_port); Conv2DWithBiasAdd base; if (!FindConv2DWithBias(ctx, bias_add.node, &base, /*check_device_compatible=*/false) || !HasSingleFanoutNode(ctx.graph_view, base.bias_add) || - !HaveSameDataType(relu, base.bias_add) || + !HaveSameDataType(activation, base.bias_add) || IsInPreserveSet(ctx, base.bias_add)) return false; // Check that data type and data format are supported on assigned device. - const Conv2DWithBiasAddAndRelu pattern{base.conv2d, base.bias_add, relu}; + const Conv2DWithBiasAddAndActivation pattern{base.conv2d, base.bias_add, + activation}; if (!IsDeviceCompatible(ctx, pattern)) return false; - // We successfully found a Conv2D+BiasAdd+Relu pattern. + // We successfully found a Conv2D+BiasAdd+Activation pattern. *matched = pattern; return true; @@ -387,16 +395,17 @@ bool FindConv2DWithBatchNorm(const RemapperContext& ctx, return true; } -bool FindConv2DWithBatchNormAndRelu(const RemapperContext& ctx, - const NodeDef* node, - Conv2DWithBatchNormAndRelu* matched) { +bool FindConv2DWithBatchNormAndActivation( + const RemapperContext& ctx, const NodeDef* node, + Conv2DWithBatchNormAndActivation* matched) { if (!EigenSupportsContractionOutputKernel()) return false; - // Root of the pattern must be a Relu. - if (!node || !IsRelu(*node) || HasControlFaninOrFanout(ctx.graph_view, node)) + // Root of the pattern must be an activation node. + if (!node || !IsSupportedActivation(*node) || + HasControlFaninOrFanout(ctx.graph_view, node)) return false; - // And input to Relu must match Conv2DWithBatchNorm pattern. + // And input to the activation node must match Conv2DWithBatchNorm pattern. 
const auto input_port = GraphView::InputPort(node, 0); const auto batch_norm = ctx.graph_view.GetRegularFanin(input_port); @@ -407,10 +416,10 @@ bool FindConv2DWithBatchNormAndRelu(const RemapperContext& ctx, IsInPreserveSet(ctx, base.fused_batch_norm)) return false; - // We successfully found a Conv2D+FusedBatchNorm+Relu pattern. + // We successfully found a Conv2D+FusedBatchNorm+Activation pattern. matched->conv2d = base.conv2d; matched->fused_batch_norm = base.fused_batch_norm; - matched->relu = node; + matched->activation = node; matched->epsilon = base.epsilon; return true; @@ -511,19 +520,19 @@ void AddFusedConv2DNode( } void AddFusedConv2DNode( - const RemapperContext& ctx, const Conv2DWithBiasAddAndRelu& matched, + const RemapperContext& ctx, const Conv2DWithBiasAddAndActivation& matched, GraphDef* optimized_graph, absl::flat_hash_set* invalidated_nodes) { DCHECK(IsDeviceCompatible(ctx, matched)) << "Unsupported fused Conv2D pattern"; - VLOG(2) << "Fuse Conv2D with BiasAdd and Relu: " - << " relu=" << matched.relu->name() + VLOG(2) << "Fuse Conv2D with BiasAdd and " << matched.activation->op() << ":" + << " activation=" << matched.activation->name() << " bias_add=" << matched.bias_add->name() << " conv2d=" << matched.conv2d->name(); NodeDef* fused_conv2d = optimized_graph->add_node(); - fused_conv2d->set_name(matched.relu->name()); + fused_conv2d->set_name(matched.activation->name()); fused_conv2d->set_op(kFusedConv2D); fused_conv2d->set_device(matched.conv2d->device()); fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input @@ -531,9 +540,9 @@ void AddFusedConv2DNode( fused_conv2d->add_input(matched.bias_add->input(1)); // 2: bias CopyConv2DAttributes(matched.conv2d, fused_conv2d); - SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd", "Relu"}); + SetFusedConv2DAttributes(fused_conv2d, {"BiasAdd", matched.activation->op()}); - invalidated_nodes->insert(matched.relu); + invalidated_nodes->insert(matched.activation); 
invalidated_nodes->insert(matched.bias_add); invalidated_nodes->insert(matched.conv2d); } @@ -601,15 +610,15 @@ void AddFusedConv2DNode( } void AddFusedConv2DNode( - const Conv2DWithBatchNormAndRelu& matched, GraphDef* optimized_graph, + const Conv2DWithBatchNormAndActivation& matched, GraphDef* optimized_graph, absl::flat_hash_set<const NodeDef*>* invalidated_nodes) { - VLOG(2) << "Fuse Conv2D with BatchNorm and Relu: relu=" - << matched.relu->name() + VLOG(2) << "Fuse Conv2D with BatchNorm and " << matched.activation->op() + << ": activation=" << matched.activation->name() << " batch_norm=" << matched.fused_batch_norm->name() << " conv2d=" << matched.conv2d->name(); NodeDef* fused_conv2d = optimized_graph->add_node(); - fused_conv2d->set_name(matched.relu->name()); + fused_conv2d->set_name(matched.activation->name()); fused_conv2d->set_op(kFusedConv2D); fused_conv2d->set_device(matched.conv2d->device()); fused_conv2d->add_input(matched.conv2d->input(0)); // 0: input @@ -620,10 +629,11 @@ void AddFusedConv2DNode( fused_conv2d->add_input(matched.fused_batch_norm->input(4)); // 5: variance CopyConv2DAttributes(matched.conv2d, fused_conv2d); - SetFusedConv2DAttributes(fused_conv2d, {"FusedBatchNorm", "Relu"}, + SetFusedConv2DAttributes(fused_conv2d, + {"FusedBatchNorm", matched.activation->op()}, /*num_args=*/4, /*epsilon=*/matched.epsilon); - invalidated_nodes->insert(matched.relu); + invalidated_nodes->insert(matched.activation); invalidated_nodes->insert(matched.fused_batch_norm); invalidated_nodes->insert(matched.conv2d); } @@ -778,9 +788,9 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, // clang-format off FusedBatchNorm fused_batch_norm; Conv2DWithBiasAdd conv2d_with_bias; - Conv2DWithBiasAddAndRelu conv2d_with_bias_and_relu; + Conv2DWithBiasAddAndActivation conv2d_with_bias_and_activation; Conv2DWithBatchNorm conv2d_with_batch_norm; - Conv2DWithBatchNormAndRelu conv2d_with_batch_norm_and_relu; + Conv2DWithBatchNormAndActivation
conv2d_with_batch_norm_and_activation; Conv2DWithSqueezeAndBiasAdd conv2d_with_squeeze_and_bias; // clang-format on @@ -795,7 +805,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, RemapperContext ctx(topo_sorted_item); // Skip nodes that were invalidated by a remapper, e.g. do not process BiasAdd - // and Relu nodes that were fused into a Conv2D node. + // and Activation nodes that were fused into a Conv2D node. absl::flat_hash_set invalidated_nodes; optimized_graph->mutable_node()->Reserve(topo_sorted_item.graph.node_size()); @@ -810,9 +820,10 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, continue; } - // Remap Conv2D+BiasAdd+Relu into the _FusedConv2D. - if (FindConv2DWithBiasAndRelu(ctx, &node, &conv2d_with_bias_and_relu)) { - AddFusedConv2DNode(ctx, conv2d_with_bias_and_relu, optimized_graph, + // Remap Conv2D+BiasAdd+Activation into the _FusedConv2D. + if (FindConv2DWithBiasAndActivation(ctx, &node, + &conv2d_with_bias_and_activation)) { + AddFusedConv2DNode(ctx, conv2d_with_bias_and_activation, optimized_graph, &invalidated_nodes); continue; } @@ -835,10 +846,10 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, continue; } - // Remap Conv2D+FusedBatchNorm+Relu into the _FusedConv2D; - if (FindConv2DWithBatchNormAndRelu(ctx, &node, - &conv2d_with_batch_norm_and_relu)) { - AddFusedConv2DNode(conv2d_with_batch_norm_and_relu, optimized_graph, + // Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D; + if (FindConv2DWithBatchNormAndActivation( + ctx, &node, &conv2d_with_batch_norm_and_activation)) { + AddFusedConv2DNode(conv2d_with_batch_norm_and_activation, optimized_graph, &invalidated_nodes); continue; } diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc index ffc242decc7..b59e97a5c40 100644 --- a/tensorflow/core/grappler/optimizers/remapper_test.cc +++ 
b/tensorflow/core/grappler/optimizers/remapper_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/remapper.h" + #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/devices.h" @@ -97,7 +98,7 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) { EXPECT_EQ(1, tensors_expected.size()); auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(1, tensors.size()); - test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); + test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-5); } } @@ -164,69 +165,85 @@ TEST_F(RemapperTest, FuseConv2DWithBias) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(RemapperTest, FuseConv2DWithBiasAndRelu) { +TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) { if (!EigenSupportsContractionOutputKernel()) return; using ::tensorflow::ops::Placeholder; - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + for (const string& activation : {"Relu", "Relu6", "Elu"}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - auto input_shape = Placeholder::Shape({8, 32, 32, 3}); - auto filter_shape = Placeholder::Shape({1, 1, 3, 128}); - auto bias_shape = Placeholder::Shape({128}); + auto input_shape = Placeholder::Shape({8, 32, 32, 3}); + auto filter_shape = Placeholder::Shape({1, 1, 3, 128}); + auto bias_shape = Placeholder::Shape({128}); - auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); - auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); - auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape); + auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); + auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); + auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape); - std::vector 
strides = {1, 1, 1, 1}; - auto conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME"); - auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); - auto relu = ops::Relu(s.WithOpName("relu"), bias_add); - auto fetch = ops::Identity(s.WithOpName("fetch"), relu); + std::vector strides = {1, 1, 1, 1}; + auto conv = + ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME"); + auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), conv, bias); - auto input_t = GenerateRandomTensor({8, 32, 32, 3}); - auto filter_t = GenerateRandomTensor({1, 1, 3, 128}); - auto bias_t = GenerateRandomTensor({128}); + ops::Identity fetch = [&]() -> ops::Identity { + auto activate = s.WithOpName("activation"); + auto fetch = s.WithOpName("fetch"); - GrapplerItem item; - item.fetch = {"fetch"}; - item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}}; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); + if (activation == "Relu") { + return ops::Identity(fetch, ops::Relu(activate, bias_add)); + } else if (activation == "Relu6") { + return ops::Identity(fetch, ops::Relu6(activate, bias_add)); + } else if (activation == "Elu") { + return ops::Identity(fetch, ops::Elu(activate, bias_add)); + } - // Place all nodes on CPU. 
- for (int i = 0; i < item.graph.node_size(); ++i) { - item.graph.mutable_node(i)->set_device("/device:CPU:0"); - } + return ops::Identity(fetch, bias); + }(); - Remapper optimizer(RewriterConfig::ON); - GraphDef output; - TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); + auto input_t = GenerateRandomTensor({8, 32, 32, 3}); + auto filter_t = GenerateRandomTensor({1, 1, 3, 128}); + auto bias_t = GenerateRandomTensor({128}); - int found = 0; - for (const NodeDef& node : output.node()) { - if (node.name() == "relu") { - EXPECT_EQ("_FusedConv2D", node.op()); - EXPECT_EQ("input", node.input(0)); - EXPECT_EQ("filter", node.input(1)); + GrapplerItem item; + item.fetch = {"fetch"}; + item.feed = {{"input", input_t}, {"filter", filter_t}, {"bias", bias_t}}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); - EXPECT_EQ(1, node.attr().at("num_args").i()); - EXPECT_EQ("bias", node.input(2)); - - const auto fused_ops = node.attr().at("fused_ops").list().s(); - ASSERT_EQ(2, fused_ops.size()); - EXPECT_EQ("BiasAdd", fused_ops[0]); - EXPECT_EQ("Relu", fused_ops[1]); - found++; + // Place all nodes on CPU. 
+ for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:CPU:0"); } - } - EXPECT_EQ(1, found); - auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); - auto tensors = EvaluateNodes(output, item.fetch, item.feed); - EXPECT_EQ(1, tensors_expected.size()); - EXPECT_EQ(1, tensors.size()); - test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); + Remapper optimizer(RewriterConfig::ON); + GraphDef output; + TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "activation") { + EXPECT_EQ("_FusedConv2D", node.op()); + EXPECT_EQ("input", node.input(0)); + EXPECT_EQ("filter", node.input(1)); + + EXPECT_EQ(1, node.attr().at("num_args").i()); + EXPECT_EQ("bias", node.input(2)); + + const auto fused_ops = node.attr().at("fused_ops").list().s(); + ASSERT_EQ(2, fused_ops.size()); + EXPECT_EQ("BiasAdd", fused_ops[0]); + EXPECT_EQ(activation, fused_ops[1]); + found++; + } + } + EXPECT_EQ(1, found); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); + } } TEST_F(RemapperTest, FuseConv2DWithBatchNorm) { @@ -306,83 +323,100 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNorm) { test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); } -TEST_F(RemapperTest, FuseConv2DWithBatchNormAndRelu) { +TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) { if (!EigenSupportsContractionOutputKernel()) return; using ops::Placeholder; - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + for (const string& activation : {"Relu", "Relu6", "Elu"}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3}); - auto filter_shape = 
ops::Placeholder::Shape({1, 1, 3, 128}); - auto scale_shape = ops::Placeholder::Shape({128}); + auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3}); + auto filter_shape = ops::Placeholder::Shape({1, 1, 3, 128}); + auto scale_shape = ops::Placeholder::Shape({128}); - auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); - auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); - auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, scale_shape); - auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, scale_shape); - auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, scale_shape); - auto variance = Placeholder(s.WithOpName("variance"), DT_FLOAT, scale_shape); + auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape); + auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape); + auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, scale_shape); + auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, scale_shape); + auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, scale_shape); + auto variance = + Placeholder(s.WithOpName("variance"), DT_FLOAT, scale_shape); - std::vector strides = {1, 1, 1, 1}; - auto conv = ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME"); - ops::FusedBatchNorm::Attrs attrs; - attrs = attrs.IsTraining(false); - auto batch_norm = ops::FusedBatchNorm(s.WithOpName("batch_norm"), conv, scale, - offset, mean, variance, attrs); - auto relu = ops::Relu(s.WithOpName("relu"), batch_norm.y); - auto fetch = ops::Identity(s.WithOpName("fetch"), relu); + std::vector strides = {1, 1, 1, 1}; + auto conv = + ops::Conv2D(s.WithOpName("conv"), input, filter, strides, "SAME"); + ops::FusedBatchNorm::Attrs attrs; + attrs = attrs.IsTraining(false); + auto batch_norm = ops::FusedBatchNorm(s.WithOpName("batch_norm"), conv, + scale, offset, mean, variance, attrs); - auto input_t = GenerateRandomTensor({8, 32, 32, 3}); - auto filter_t = 
GenerateRandomTensor({1, 1, 3, 128}); - auto scale_t = GenerateRandomTensor({128}); - auto offset_t = GenerateRandomTensor({128}); - auto mean_t = GenerateRandomTensor({128}); - auto variance_t = GenerateRandomTensor({128}); + ops::Identity fetch = [&]() -> ops::Identity { + auto activate = s.WithOpName("activation"); + auto fetch = s.WithOpName("fetch"); - GrapplerItem item; - item.fetch = {"fetch"}; - item.feed = {{"input", input_t}, {"filter", filter_t}, - {"scale", scale_t}, {"offset", offset_t}, - {"mean", mean_t}, {"variance", variance_t}}; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); + if (activation == "Relu") { + return ops::Identity(fetch, ops::Relu(activate, batch_norm.y)); + } else if (activation == "Relu6") { + return ops::Identity(fetch, ops::Relu6(activate, batch_norm.y)); + } else if (activation == "Elu") { + return ops::Identity(fetch, ops::Elu(activate, batch_norm.y)); + } - // Place all nodes on CPU. - for (int i = 0; i < item.graph.node_size(); ++i) { - item.graph.mutable_node(i)->set_device("/device:CPU:0"); - } + return ops::Identity(fetch, batch_norm.y); + }(); - Remapper optimizer(RewriterConfig::ON); - GraphDef output; - TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); + auto input_t = GenerateRandomTensor({8, 32, 32, 3}); + auto filter_t = GenerateRandomTensor({1, 1, 3, 128}); + auto scale_t = GenerateRandomTensor({128}); + auto offset_t = GenerateRandomTensor({128}); + auto mean_t = GenerateRandomTensor({128}); + auto variance_t = GenerateRandomTensor({128}); - int found = 0; - for (const NodeDef& node : output.node()) { - if (node.name() == "relu") { - EXPECT_EQ("_FusedConv2D", node.op()); - EXPECT_EQ("input", node.input(0)); - EXPECT_EQ("filter", node.input(1)); + GrapplerItem item; + item.fetch = {"fetch"}; + item.feed = {{"input", input_t}, {"filter", filter_t}, + {"scale", scale_t}, {"offset", offset_t}, + {"mean", mean_t}, {"variance", variance_t}}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); - EXPECT_EQ(4, 
node.attr().at("num_args").i()); - EXPECT_EQ("scale", node.input(2)); - EXPECT_EQ("offset", node.input(3)); - EXPECT_EQ("mean", node.input(4)); - EXPECT_EQ("variance", node.input(5)); - - const auto fused_ops = node.attr().at("fused_ops").list().s(); - EXPECT_EQ(2, fused_ops.size()); - EXPECT_EQ("FusedBatchNorm", fused_ops[0]); - EXPECT_EQ("Relu", fused_ops[1]); - found++; + // Place all nodes on CPU. + for (int i = 0; i < item.graph.node_size(); ++i) { + item.graph.mutable_node(i)->set_device("/device:CPU:0"); } - } - EXPECT_EQ(1, found); - auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); - auto tensors = EvaluateNodes(output, item.fetch, item.feed); - EXPECT_EQ(1, tensors_expected.size()); - EXPECT_EQ(1, tensors.size()); - test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); + Remapper optimizer(RewriterConfig::ON); + GraphDef output; + TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "activation") { + EXPECT_EQ("_FusedConv2D", node.op()); + EXPECT_EQ("input", node.input(0)); + EXPECT_EQ("filter", node.input(1)); + + EXPECT_EQ(4, node.attr().at("num_args").i()); + EXPECT_EQ("scale", node.input(2)); + EXPECT_EQ("offset", node.input(3)); + EXPECT_EQ("mean", node.input(4)); + EXPECT_EQ("variance", node.input(5)); + + const auto fused_ops = node.attr().at("fused_ops").list().s(); + EXPECT_EQ(2, fused_ops.size()); + EXPECT_EQ("FusedBatchNorm", fused_ops[0]); + EXPECT_EQ(activation, fused_ops[1]); + found++; + } + } + EXPECT_EQ(1, found); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed); + auto tensors = EvaluateNodes(output, item.fetch, item.feed); + EXPECT_EQ(1, tensors_expected.size()); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6); + } } TEST_F(RemapperTest, FuseConv2DWithSqueezeAndBias) { diff --git a/tensorflow/core/grappler/utils/grappler_test.cc 
b/tensorflow/core/grappler/utils/grappler_test.cc index 98dd1e5cefe..47397b589f0 100644 --- a/tensorflow/core/grappler/utils/grappler_test.cc +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -88,6 +88,7 @@ GrapplerTest::GrapplerTest() { cfg->set_layout_optimizer(RewriterConfig::OFF); cfg->set_loop_optimization(RewriterConfig::OFF); cfg->set_pin_to_host_optimization(RewriterConfig::OFF); + cfg->set_remapping(RewriterConfig::OFF); } std::vector<Tensor> GrapplerTest::EvaluateNodes(