From d73778d2054f187ce6b6e5a7915d78aac6674e62 Mon Sep 17 00:00:00 2001
From: Eugene Zhulenev
Date: Mon, 17 Jun 2019 10:36:15 -0700
Subject: [PATCH] Add a remapper pattern for FusedBatchNorm[is_training=False].

tf_cnn_benchmarks --forward_only=true --model=resnet50:
fp32: 1010 images/sec -> 1145 images/sec
fp16: 1990 images/sec -> 3010 images/sec

PiperOrigin-RevId: 253610151
---
 .../core/grappler/optimizers/remapper.cc      |  39 +-
 .../core/grappler/optimizers/remapper_test.cc | 356 ++++++++++--------
 2 files changed, 215 insertions(+), 180 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index f826774e6ae..e34a8f9ac27 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -761,8 +761,6 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
   if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
            .ok())
     return false;
-  // TODO(ezhulenev): Add support for is_training=True and custom CUDA kernel.
-  if (!is_training) return false;
 
   // In training mode we rely on cuDNN for computing FusedBatchNorm with side
   // inputs and activation, and it has its own limitations. In inference mode
@@ -796,7 +794,7 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
       !HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
     return false;
 
-  // Check that only one node consumes the output of a FusedBatchNorm.
+  // Check that only one node consumes the 0-th output of a FusedBatchNorm.
   if (HasControlFaninOrFanout(fused_batch_norm) ||
       !HasAtMostOneFanoutAtPort0(fused_batch_norm) ||
       IsInPreserveSet(ctx, fused_batch_norm_node_def))
@@ -1216,11 +1214,14 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx,
   const NodeDef& activation = graph->node(matched.activation);
 
   VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
-          << " side_input="
+          << " activation=" << activation.name() << " side_input="
          << (matched.side_input != kMissingIndex
                  ? graph->node(matched.side_input).name()
                  : "")
-          << " activation=" << activation.name()
+          << " invalidated="
+          << (matched.invalidated != kMissingIndex
+                  ? graph->node(matched.invalidated).name()
+                  : "")
          << " fused_batch_norm=" << fused_batch_norm.name();
 
   // Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
@@ -1512,23 +1513,6 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
 
 Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
                           GraphDef* optimized_graph) {
-  // Supported graph patterns.
-  // clang-format off
-  FusedBatchNorm fused_batch_norm;
-  FusedBatchNormEx fused_batch_norm_ex;
-  ContractionWithBiasAdd contract_with_bias;
-  ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
-#ifndef INTEL_MKL
-  ContractionWithBatchNorm contract_with_batch_norm;
-  ContractionWithBatchNormAndActivation contract_with_batch_norm_and_activation;
-  ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
-#endif  // !INTEL_MKL
-#ifdef INTEL_MKL
-  ContractionWithBiasAddAndAdd contract_with_bias_and_add;
-  ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
-#endif  // INTEL_MKL
-  // clang-format on
-
   GrapplerItem mutable_item = item;
   Status status;
   RemapperContext ctx(&mutable_item, &status);
@@ -1556,6 +1540,9 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 
 #ifdef INTEL_MKL
+    ContractionWithBiasAddAndAdd contract_with_bias_and_add;
+    ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
+
     if (!item.optimization_options().is_eager_mode) {
       // Remap Conv2D+BiasAdd+Add+relu into the _FusedConv2D.
       if (FindContractionWithBiasAndAddActivation(
@@ -1578,6 +1565,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 #endif  //! INTEL_MKL
 
     // Remap {Conv2D,MatMul}+BiasAdd into the _Fused{Conv2D,MatMul}
+    ContractionWithBiasAdd contract_with_bias;
     if (allow_non_differentiable_rewrites &&
         FindContractionWithBias(ctx, i, &contract_with_bias)) {
       TF_RETURN_IF_ERROR(AddFusedContractionNode(
@@ -1586,6 +1574,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 
     // Remap {Conv2D,MatMul}+BiasAdd+Activation into the _Fused{Conv2D,MatMul}.
+    ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
     if (allow_non_differentiable_rewrites &&
         FindContractionWithBiasAndActivation(
            ctx, i, &contract_with_bias_and_activation)) {
@@ -1603,6 +1592,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     // Remove this once TF-MKL supports _FusedConv2D with these operations.
 #ifndef INTEL_MKL
     // Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
+    ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
     if (allow_non_differentiable_rewrites &&
         FindConv2DWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
       TF_RETURN_IF_ERROR(
@@ -1612,6 +1602,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 
     // Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
+    ContractionWithBatchNorm contract_with_batch_norm;
     if (allow_non_differentiable_rewrites &&
         FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
       TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
@@ -1621,6 +1612,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 
     // Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
+    ContractionWithBatchNormAndActivation
+        contract_with_batch_norm_and_activation;
     if (allow_non_differentiable_rewrites &&
         FindConv2DWithBatchNormAndActivation(
            ctx, i, &contract_with_batch_norm_and_activation)) {
@@ -1644,6 +1637,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 
     // Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
+    FusedBatchNormEx fused_batch_norm_ex;
     if (allow_non_differentiable_rewrites &&
         FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
       TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
@@ -1653,6 +1647,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
 
     // During inference, most of the inputs to FusedBatchNorm are constant, and
     // we can therefore replace the op with a much cheaper set of primitives.
+    FusedBatchNorm fused_batch_norm;
     if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
       TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
       continue;
diff --git a/tensorflow/core/grappler/optimizers/remapper_test.cc b/tensorflow/core/grappler/optimizers/remapper_test.cc
index 2e9f937928f..b0a1c9097fe 100644
--- a/tensorflow/core/grappler/optimizers/remapper_test.cc
+++ b/tensorflow/core/grappler/optimizers/remapper_test.cc
@@ -112,185 +112,225 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
     ASSERT_EQ(tensors_expected.size(), 1);
     auto tensors = EvaluateNodes(output, item.fetch);
     ASSERT_EQ(tensors.size(), 1);
-    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-5);
+    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-3);
   }
 }
 
 TEST_F(RemapperTest, FuseBatchNormWithRelu) {
   using ::tensorflow::ops::Placeholder;
 
+  for (bool is_training : {true, false}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
 #if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
-  LOG(INFO) << "Skip FuseBatchNormWithRelu test. It requires "
-               "CUDNN_VERSION >= 7402.";
-#else
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
-  auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
-  auto channels_shape = ops::Placeholder::Shape({24});
-
-  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
-  auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
-  auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
-  auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
-  auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
-  auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
-
-  float epsilon = 0.1f;
-  auto fbn = ops::FusedBatchNormV3(
-      s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
-      ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
-          "NHWC"));
-  auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
-  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
-
-  auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
-  auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto mean_t = GenerateRandomTensor<DT_FLOAT>({0});  // empty for training
-  auto var_t = GenerateRandomTensor<DT_FLOAT>({0});   // empty for training
-
-  GrapplerItem item;
-  item.fetch = {"fetch"};
-  item.feed = {{"input", input_t},
-               {"scale", scale_t},
-               {"offset", offset_t},
-               {"mean", mean_t},
-               {"var", var_t}};
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
-
-  // Place all nodes on GPU.
-  for (int i = 0; i < item.graph.node_size(); ++i) {
-    item.graph.mutable_node(i)->set_device("/device:GPU:0");
-  }
-
-  Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-
-  int found = 0;
-  for (const NodeDef& node : output.node()) {
-    if (node.name() == "relu") {
-      EXPECT_EQ(node.op(), "Identity");
-      ASSERT_EQ(node.input_size(), 1);
-      EXPECT_EQ(node.input(0), "fused_batch_norm");
-      found++;
+    if (is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithRelu"
+                << "[is_training=" << is_training << "] "
+                << "test. It requires CUDNN_VERSION >= 7402.";
+      continue;
     }
-    if (node.name() == "fused_batch_norm") {
-      EXPECT_EQ(node.op(), "_FusedBatchNormEx");
-      ASSERT_EQ(node.input_size(), 5);
-      EXPECT_EQ(node.input(0), "input_cast");
-      EXPECT_EQ(node.input(1), "scale");
-      EXPECT_EQ(node.input(2), "offset");
-      EXPECT_EQ(node.input(3), "mean");
-      EXPECT_EQ(node.input(4), "var");
+#endif
 
-      auto attr = node.attr();
-      EXPECT_EQ(attr["num_side_inputs"].i(), 0);
-      EXPECT_EQ(attr["activation_mode"].s(), "Relu");
-      found++;
+#if !defined(GOOGLE_CUDA)
+    if (!is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithRelu"
+                << "[is_training=" << is_training << "]";
+      continue;
+    }
+#endif
+
+    const int num_channels = 24;
+
+    TensorShape channel_shape({num_channels});
+    TensorShape empty_shape({0});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
+                             ops::Placeholder::Shape({2, 8, 8, num_channels}));
+    auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
+    auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
+    auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
+    auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
+    auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
+
+    float epsilon = 0.1f;
+    auto fbn = ops::FusedBatchNormV3(
+        s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
+        ops::FusedBatchNormV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat("NHWC"));
+    auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
+    auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
+
+    auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
+    auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                             : channel_shape);
+    auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                            : channel_shape);
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"input", input_t},
+                 {"scale", scale_t},
+                 {"offset", offset_t},
+                 {"mean", mean_t},
+                 {"var", var_t}};
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on GPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:GPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "relu") {
+        EXPECT_EQ(node.op(), "Identity");
+        ASSERT_EQ(node.input_size(), 1);
+        EXPECT_EQ(node.input(0), "fused_batch_norm");
+        found++;
+      }
+      if (node.name() == "fused_batch_norm") {
+        EXPECT_EQ(node.op(), "_FusedBatchNormEx");
+        ASSERT_EQ(node.input_size(), 5);
+        EXPECT_EQ(node.input(0), "input_cast");
+        EXPECT_EQ(node.input(1), "scale");
+        EXPECT_EQ(node.input(2), "offset");
+        EXPECT_EQ(node.input(3), "mean");
+        EXPECT_EQ(node.input(4), "var");
+
+        auto attr = node.attr();
+        EXPECT_EQ(attr["num_side_inputs"].i(), 0);
+        EXPECT_EQ(attr["activation_mode"].s(), "Relu");
+        found++;
+      }
+    }
+    EXPECT_EQ(found, 2);
+
+    if (GetNumAvailableGPUs() > 0) {
+      auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+      ASSERT_EQ(tensors_expected.size(), 1);
+      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+      ASSERT_EQ(tensors.size(), 1);
+      test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
+    }
   }
-  EXPECT_EQ(found, 2);
-
-  if (GetNumAvailableGPUs() > 0) {
-    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-    ASSERT_EQ(tensors_expected.size(), 1);
-    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-    ASSERT_EQ(tensors.size(), 1);
-    test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
-  }
-#endif  // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
 }
 
 TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
   using ::tensorflow::ops::Placeholder;
 
+  for (bool is_training : {true, false}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
 #if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
-  LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu test. It requires "
-               "CUDNN_VERSION >= 7402.";
-#else
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
-  auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
-  auto channels_shape = ops::Placeholder::Shape({24});
-
-  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
-  auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
-  auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
-  auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
-  auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
-  auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
-  auto side_input =
-      Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape);
-  auto side_input_cast =
-      ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
-
-  float epsilon = 0.1f;
-  auto fbn = ops::FusedBatchNormV3(
-      s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
-      ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
-          "NHWC"));
-  auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
-  auto relu = ops::Relu(s.WithOpName("relu"), add);
-  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
-
-  auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
-  auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto mean_t = GenerateRandomTensor<DT_FLOAT>({0});  // empty for training
-  auto var_t = GenerateRandomTensor<DT_FLOAT>({0});   // empty for training
-  auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
-
-  GrapplerItem item;
-  item.fetch = {"fetch"};
-  item.feed = {{"input", input_t},   {"scale", scale_t},
-               {"offset", offset_t}, {"mean", mean_t},
-               {"var", var_t},       {"side_input", side_input_t}};
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
-
-  // Place all nodes on GPU.
-  for (int i = 0; i < item.graph.node_size(); ++i) {
-    item.graph.mutable_node(i)->set_device("/device:GPU:0");
-  }
-
-  Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-
-  int found = 0;
-  for (const NodeDef& node : output.node()) {
-    if (node.name() == "relu") {
-      EXPECT_EQ(node.op(), "Identity");
-      ASSERT_EQ(node.input_size(), 1);
-      EXPECT_EQ(node.input(0), "fused_batch_norm");
-      found++;
+    if (is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
+                << "[is_training=" << is_training << "] "
+                << "test. It requires CUDNN_VERSION >= 7402.";
+      continue;
     }
-    if (node.name() == "fused_batch_norm") {
-      EXPECT_EQ(node.op(), "_FusedBatchNormEx");
-      ASSERT_EQ(node.input_size(), 6);
-      EXPECT_EQ(node.input(0), "input_cast");
-      EXPECT_EQ(node.input(1), "scale");
-      EXPECT_EQ(node.input(2), "offset");
-      EXPECT_EQ(node.input(3), "mean");
-      EXPECT_EQ(node.input(4), "var");
-      EXPECT_EQ(node.input(5), "side_input_cast");
+#endif
 
-      auto attr = node.attr();
-      EXPECT_EQ(attr["num_side_inputs"].i(), 1);
-      EXPECT_EQ(attr["activation_mode"].s(), "Relu");
-      found++;
+#if !defined(GOOGLE_CUDA)
+    if (!is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
+                << "[is_training=" << is_training << "]";
+      continue;
+    }
+#endif
+
+    const int num_channels = 24;
+
+    TensorShape input_shape({2, 8, 8, num_channels});
+    TensorShape channel_shape({num_channels});
+    TensorShape empty_shape({0});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
+                             ops::Placeholder::Shape(input_shape));
+    auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
+    auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
+    auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
+    auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
+    auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
+    auto side_input = Placeholder(s.WithOpName("side_input"), DT_FLOAT);
+    auto side_input_cast =
+        ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
+
+    float epsilon = 0.1f;
+    auto fbn = ops::FusedBatchNormV3(
+        s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
+        ops::FusedBatchNormV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat("NHWC"));
+    auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
+    auto relu = ops::Relu(s.WithOpName("relu"), add);
+    auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
+
+    auto input_t = GenerateRandomTensor<DT_FLOAT>(input_shape);
+    auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                             : channel_shape);
+    auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                            : channel_shape);
+    auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"input", input_t},   {"scale", scale_t},
+                 {"offset", offset_t}, {"mean", mean_t},
+                 {"var", var_t},       {"side_input", side_input_t}};
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on GPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:GPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "relu") {
+        EXPECT_EQ(node.op(), "Identity");
+        ASSERT_EQ(node.input_size(), 1);
+        EXPECT_EQ(node.input(0), "fused_batch_norm");
+        found++;
+      }
+      if (node.name() == "fused_batch_norm") {
+        EXPECT_EQ(node.op(), "_FusedBatchNormEx");
+        ASSERT_EQ(node.input_size(), 6);
+        EXPECT_EQ(node.input(0), "input_cast");
+        EXPECT_EQ(node.input(1), "scale");
+        EXPECT_EQ(node.input(2), "offset");
+        EXPECT_EQ(node.input(3), "mean");
+        EXPECT_EQ(node.input(4), "var");
+        EXPECT_EQ(node.input(5), "side_input_cast");
+
+        auto attr = node.attr();
+        EXPECT_EQ(attr["num_side_inputs"].i(), 1);
+        EXPECT_EQ(attr["activation_mode"].s(), "Relu");
+        found++;
+      }
+    }
+    EXPECT_EQ(found, 2);
+
+    if (GetNumAvailableGPUs() > 0) {
+      auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+      ASSERT_EQ(tensors_expected.size(), 1);
+      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+      ASSERT_EQ(tensors.size(), 1);
+      test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
+    }
   }
-  EXPECT_EQ(found, 2);
-
-  if (GetNumAvailableGPUs() > 0) {
-    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-    ASSERT_EQ(tensors_expected.size(), 1);
-    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-    ASSERT_EQ(tensors.size(), 1);
-    test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
-  }
-#endif  // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
 }
 
 TEST_F(RemapperTest, FuseConv2DWithBias) {
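
The updated tests above double as a usage recipe for the new inference-mode fusion. The stand-alone sketch below (not part of the patch) condenses them for the is_training=false case: it builds the FusedBatchNormV3[is_training=false] + Relu pattern, pins every node to GPU, runs the Remapper at the AGGRESSIVE level so placeholder shapes are trusted, and reports which op ends up implementing the batch norm node. The include list, build target, and main() harness are illustrative assumptions; on a CUDA/cuDNN build the node should come back as _FusedBatchNormEx, otherwise it is left as FusedBatchNormV3.

// remapper_fusion_sketch.cc -- illustrative only; mirrors the updated
// FuseBatchNormWithRelu test for the is_training=false (inference) case.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

int main() {
  using tensorflow::DT_FLOAT;
  using tensorflow::DT_HALF;
  namespace ops = tensorflow::ops;

  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Inference-mode FusedBatchNormV3: mean/variance are real (non-empty) inputs.
  auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                ops::Placeholder::Shape({2, 8, 8, 24}));
  auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
  auto scale = ops::Placeholder(s.WithOpName("scale"), DT_FLOAT);
  auto offset = ops::Placeholder(s.WithOpName("offset"), DT_FLOAT);
  auto mean = ops::Placeholder(s.WithOpName("mean"), DT_FLOAT);
  auto var = ops::Placeholder(s.WithOpName("var"), DT_FLOAT);

  auto fbn = ops::FusedBatchNormV3(
      s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
      ops::FusedBatchNormV3::IsTraining(false).Epsilon(0.1f).DataFormat(
          "NHWC"));
  auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);

  tensorflow::grappler::GrapplerItem item;
  item.fetch = {"fetch"};
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  // The fusion is only emitted for GPU placements, so pin every node to GPU:0.
  for (int i = 0; i < item.graph.node_size(); ++i) {
    item.graph.mutable_node(i)->set_device("/device:GPU:0");
  }

  // AGGRESSIVE lets the remapper trust the placeholder shapes, as in the test.
  tensorflow::grappler::Remapper optimizer(
      tensorflow::RewriterConfig::AGGRESSIVE);
  tensorflow::GraphDef output;
  TF_CHECK_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));

  for (const tensorflow::NodeDef& node : output.node()) {
    if (node.name() == "fused_batch_norm") {
      // Expected on a CUDA/cuDNN build: "_FusedBatchNormEx".
      LOG(INFO) << "fused_batch_norm is implemented by op: " << node.op();
    }
  }
  return 0;
}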