Add a remapper pattern for FusedBatchNorm[is_training=False].

tf_cnn_benchmarks --forward_only=true --model=resnet50:
  fp32: 1010 images/sec -> 1145 images/sec
  fp16: 1990 images/sec -> 3010 images/sec

PiperOrigin-RevId: 253610151
Author:    Eugene Zhulenev
Committer: TensorFlower Gardener
Date:      2019-06-17 10:36:15 -07:00
Commit:    d73778d205 (parent: ce59aa0be8)

2 changed files with 215 additions and 180 deletions
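For reference, the pattern this change targets is the inference-mode subgraph FusedBatchNormV3[is_training=false] feeding a Relu (optionally through an Add of a side input), which the remapper collapses into a single _FusedBatchNormEx node. Below is a minimal sketch of a graph that triggers the rewrite; it mirrors the updated unit test further down, and the helper name, shapes, and epsilon are illustrative only, not part of the commit.

    #include "tensorflow/cc/framework/scope.h"
    #include "tensorflow/cc/ops/standard_ops.h"
    #include "tensorflow/core/framework/graph.pb.h"
    #include "tensorflow/core/lib/core/status.h"

    // Builds Cast -> FusedBatchNormV3[is_training=false] -> Relu, the subgraph
    // that the Grappler remapper rewrites into one _FusedBatchNormEx node.
    tensorflow::GraphDef BuildBatchNormReluPattern() {
      using ::tensorflow::ops::Placeholder;
      tensorflow::Scope s = tensorflow::Scope::NewRootScope();

      auto input = Placeholder(s.WithOpName("input"), tensorflow::DT_FLOAT,
                               Placeholder::Shape({2, 8, 8, 24}));
      auto input_cast = tensorflow::ops::Cast(s.WithOpName("input_cast"), input,
                                              tensorflow::DT_HALF);
      auto scale = Placeholder(s.WithOpName("scale"), tensorflow::DT_FLOAT);
      auto offset = Placeholder(s.WithOpName("offset"), tensorflow::DT_FLOAT);
      auto mean = Placeholder(s.WithOpName("mean"), tensorflow::DT_FLOAT);
      auto var = Placeholder(s.WithOpName("var"), tensorflow::DT_FLOAT);

      // Inference mode: is_training=false, so mean/var are regular inputs.
      auto fbn = tensorflow::ops::FusedBatchNormV3(
          s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
          tensorflow::ops::FusedBatchNormV3::IsTraining(false)
              .Epsilon(0.1f)
              .DataFormat("NHWC"));
      auto relu = tensorflow::ops::Relu(s.WithOpName("relu"), fbn.y);
      // A single consumer of the Relu output, as the pattern matcher requires.
      auto fetch = tensorflow::ops::Identity(s.WithOpName("fetch"), relu);

      tensorflow::GraphDef graph;
      TF_CHECK_OK(s.ToGraphDef(&graph));
      return graph;
    }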

File: tensorflow/core/grappler/optimizers/remapper.cc

@@ -761,8 +761,6 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
   if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
            .ok())
     return false;
-  // TODO(ezhulenev): Add support for is_training=True and custom CUDA kernel.
-  if (!is_training) return false;
 
   // In training mode we rely on cuDNN for computing FusedBatchNorm with side
   // inputs and activation, and it has its own limitations. In inference mode
@@ -796,7 +794,7 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
       !HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
     return false;
 
-  // Check that only one node consumes the output of a FusedBatchNorm.
+  // Check that only one node consumes the 0-th output of a FusedBatchNorm.
   if (HasControlFaninOrFanout(fused_batch_norm) ||
       !HasAtMostOneFanoutAtPort0(fused_batch_norm) ||
       IsInPreserveSet(ctx, fused_batch_norm_node_def))
@@ -1216,11 +1214,14 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx,
   const NodeDef& activation = graph->node(matched.activation);
 
   VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
-          << " side_input="
+          << " activation=" << activation.name() << " side_input="
           << (matched.side_input != kMissingIndex
                   ? graph->node(matched.side_input).name()
                   : "<none>")
-          << " activation=" << activation.name()
+          << " invalidated="
+          << (matched.invalidated != kMissingIndex
+                  ? graph->node(matched.invalidated).name()
+                  : "<none>")
           << " fused_batch_norm=" << fused_batch_norm.name();
 
   // Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
@ -1512,23 +1513,6 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item, Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) { GraphDef* optimized_graph) {
// Supported graph patterns.
// clang-format off
FusedBatchNorm fused_batch_norm;
FusedBatchNormEx fused_batch_norm_ex;
ContractionWithBiasAdd contract_with_bias;
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
#ifndef INTEL_MKL
ContractionWithBatchNorm contract_with_batch_norm;
ContractionWithBatchNormAndActivation contract_with_batch_norm_and_activation;
ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
#endif // !INTEL_MKL
#ifdef INTEL_MKL
ContractionWithBiasAddAndAdd contract_with_bias_and_add;
ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
#endif // INTEL_MKL
// clang-format on
GrapplerItem mutable_item = item; GrapplerItem mutable_item = item;
Status status; Status status;
RemapperContext ctx(&mutable_item, &status); RemapperContext ctx(&mutable_item, &status);
@@ -1556,6 +1540,9 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 
 #ifdef INTEL_MKL
+    ContractionWithBiasAddAndAdd contract_with_bias_and_add;
+    ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
+
     if (!item.optimization_options().is_eager_mode) {
       // Remap Conv2D+BiasAdd+Add+relu into the _FusedConv2D.
       if (FindContractionWithBiasAndAddActivation(
@@ -1578,6 +1565,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
 #endif  // !INTEL_MKL
 
     // Remap {Conv2D,MatMul}+BiasAdd into the _Fused{Conv2D,MatMul}
+    ContractionWithBiasAdd contract_with_bias;
     if (allow_non_differentiable_rewrites &&
         FindContractionWithBias(ctx, i, &contract_with_bias)) {
       TF_RETURN_IF_ERROR(AddFusedContractionNode(
@@ -1586,6 +1574,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 
     // Remap {Conv2D,MatMul}+BiasAdd+Activation into the _Fused{Conv2D,MatMul}.
+    ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
     if (allow_non_differentiable_rewrites &&
         FindContractionWithBiasAndActivation(
             ctx, i, &contract_with_bias_and_activation)) {
@@ -1603,6 +1592,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     // Remove this once TF-MKL supports _FusedConv2D with these operations.
 #ifndef INTEL_MKL
     // Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
+    ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
     if (allow_non_differentiable_rewrites &&
         FindConv2DWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
       TF_RETURN_IF_ERROR(
@@ -1612,6 +1602,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 
     // Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
+    ContractionWithBatchNorm contract_with_batch_norm;
     if (allow_non_differentiable_rewrites &&
         FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
       TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
@@ -1621,6 +1612,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 
     // Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
+    ContractionWithBatchNormAndActivation
+        contract_with_batch_norm_and_activation;
     if (allow_non_differentiable_rewrites &&
         FindConv2DWithBatchNormAndActivation(
             ctx, i, &contract_with_batch_norm_and_activation)) {
@@ -1644,6 +1637,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 
     // Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
+    FusedBatchNormEx fused_batch_norm_ex;
     if (allow_non_differentiable_rewrites &&
         FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
       TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
@@ -1653,6 +1647,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
 
     // During inference, most of the inputs to FusedBatchNorm are constant, and
    // we can therefore replace the op with a much cheaper set of primitives.
+    FusedBatchNorm fused_batch_norm;
     if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
       TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
       continue;
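The "much cheaper set of primitives" mentioned in the comment above comes down to the standard batch-norm inference arithmetic, with the optional side input and the activation folded in by _FusedBatchNormEx. A per-element reference sketch follows (scalar form, assuming the usual FusedBatchNorm inference formula; the helper is hypothetical and is not the actual Eigen/cuDNN kernel):

    #include <algorithm>
    #include <cmath>

    // Per-element reference for inference-mode _FusedBatchNormEx with a Relu
    // activation and an optional side input (channel index fixed). The real
    // kernel is vectorized and dispatched to Eigen or cuDNN.
    float FusedBatchNormExRef(float x, float scale, float offset, float mean,
                              float variance, float epsilon,
                              float side_input /* 0.0f if absent */) {
      const float scaling = scale / std::sqrt(variance + epsilon);
      const float normalized = (x - mean) * scaling + offset;
      return std::max(0.0f, normalized + side_input);  // activation_mode="Relu"
    }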

File: tensorflow/core/grappler/optimizers/remapper_test.cc

@@ -112,185 +112,225 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
     ASSERT_EQ(tensors_expected.size(), 1);
     auto tensors = EvaluateNodes(output, item.fetch);
     ASSERT_EQ(tensors.size(), 1);
-    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-5);
+    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-3);
   }
 }
 
 TEST_F(RemapperTest, FuseBatchNormWithRelu) {
   using ::tensorflow::ops::Placeholder;
 
-#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
-  LOG(INFO) << "Skip FuseBatchNormWithRelu test. It requires "
-               "CUDNN_VERSION >= 7402.";
-#else
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
-  auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
-  auto channels_shape = ops::Placeholder::Shape({24});
-
-  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
-  auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
-  auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
-  auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
-  auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
-  auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
-
-  float epsilon = 0.1f;
-  auto fbn = ops::FusedBatchNormV3(
-      s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
-      ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
-          "NHWC"));
-  auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
-  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
-
-  auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
-  auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto mean_t = GenerateRandomTensor<DT_FLOAT>({0});  // empty for training
-  auto var_t = GenerateRandomTensor<DT_FLOAT>({0});   // empty for training
-
-  GrapplerItem item;
-  item.fetch = {"fetch"};
-  item.feed = {{"input", input_t},
-               {"scale", scale_t},
-               {"offset", offset_t},
-               {"mean", mean_t},
-               {"var", var_t}};
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
-
-  // Place all nodes on GPU.
-  for (int i = 0; i < item.graph.node_size(); ++i) {
-    item.graph.mutable_node(i)->set_device("/device:GPU:0");
-  }
-
-  Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-
-  int found = 0;
-  for (const NodeDef& node : output.node()) {
-    if (node.name() == "relu") {
-      EXPECT_EQ(node.op(), "Identity");
-      ASSERT_EQ(node.input_size(), 1);
-      EXPECT_EQ(node.input(0), "fused_batch_norm");
-      found++;
-    }
-    if (node.name() == "fused_batch_norm") {
-      EXPECT_EQ(node.op(), "_FusedBatchNormEx");
-      ASSERT_EQ(node.input_size(), 5);
-      EXPECT_EQ(node.input(0), "input_cast");
-      EXPECT_EQ(node.input(1), "scale");
-      EXPECT_EQ(node.input(2), "offset");
-      EXPECT_EQ(node.input(3), "mean");
-      EXPECT_EQ(node.input(4), "var");
-
-      auto attr = node.attr();
-      EXPECT_EQ(attr["num_side_inputs"].i(), 0);
-      EXPECT_EQ(attr["activation_mode"].s(), "Relu");
-      found++;
-    }
-  }
-  EXPECT_EQ(found, 2);
-
-  if (GetNumAvailableGPUs() > 0) {
-    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-    ASSERT_EQ(tensors_expected.size(), 1);
-    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-    ASSERT_EQ(tensors.size(), 1);
-    test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
-  }
-#endif  // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+  for (bool is_training : {true, false}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+    if (is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithRelu"
+                << "[is_training=" << is_training << "] "
+                << "test. It requires CUDNN_VERSION >= 7402.";
+      continue;
+    }
+#endif
+
+#if !defined(GOOGLE_CUDA)
+    if (!is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithRelu"
+                << "[is_training=" << is_training << "]";
+      continue;
+    }
+#endif
+
+    const int num_channels = 24;
+
+    TensorShape channel_shape({num_channels});
+    TensorShape empty_shape({0});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
+                             ops::Placeholder::Shape({2, 8, 8, num_channels}));
+    auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
+    auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
+    auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
+    auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
+    auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
+
+    float epsilon = 0.1f;
+    auto fbn = ops::FusedBatchNormV3(
+        s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
+        ops::FusedBatchNormV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat("NHWC"));
+    auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
+    auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
+
+    auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
+    auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                             : channel_shape);
+    auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                            : channel_shape);
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"input", input_t},
+                 {"scale", scale_t},
+                 {"offset", offset_t},
+                 {"mean", mean_t},
+                 {"var", var_t}};
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on GPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:GPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "relu") {
+        EXPECT_EQ(node.op(), "Identity");
+        ASSERT_EQ(node.input_size(), 1);
+        EXPECT_EQ(node.input(0), "fused_batch_norm");
+        found++;
+      }
+      if (node.name() == "fused_batch_norm") {
+        EXPECT_EQ(node.op(), "_FusedBatchNormEx");
+        ASSERT_EQ(node.input_size(), 5);
+        EXPECT_EQ(node.input(0), "input_cast");
+        EXPECT_EQ(node.input(1), "scale");
+        EXPECT_EQ(node.input(2), "offset");
+        EXPECT_EQ(node.input(3), "mean");
+        EXPECT_EQ(node.input(4), "var");
+
+        auto attr = node.attr();
+        EXPECT_EQ(attr["num_side_inputs"].i(), 0);
+        EXPECT_EQ(attr["activation_mode"].s(), "Relu");
+        found++;
+      }
+    }
+    EXPECT_EQ(found, 2);
+
+    if (GetNumAvailableGPUs() > 0) {
+      auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+      ASSERT_EQ(tensors_expected.size(), 1);
+      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+      ASSERT_EQ(tensors.size(), 1);
+      test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
+    }
+  }
 }
 
 TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
   using ::tensorflow::ops::Placeholder;
 
-#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
-  LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu test. It requires "
-               "CUDNN_VERSION >= 7402.";
-#else
-  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-
-  auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
-  auto channels_shape = ops::Placeholder::Shape({24});
-
-  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
-  auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
-  auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
-  auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
-  auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
-  auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
-  auto side_input =
-      Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape);
-  auto side_input_cast =
-      ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
-
-  float epsilon = 0.1f;
-  auto fbn = ops::FusedBatchNormV3(
-      s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
-      ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
-          "NHWC"));
-  auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
-  auto relu = ops::Relu(s.WithOpName("relu"), add);
-  auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
-
-  auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
-  auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto mean_t = GenerateRandomTensor<DT_FLOAT>({0});  // empty for training
-  auto var_t = GenerateRandomTensor<DT_FLOAT>({0});   // empty for training
-  auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
-
-  GrapplerItem item;
-  item.fetch = {"fetch"};
-  item.feed = {{"input", input_t},   {"scale", scale_t},
-               {"offset", offset_t}, {"mean", mean_t},
-               {"var", var_t},       {"side_input", side_input_t}};
-  TF_ASSERT_OK(s.ToGraphDef(&item.graph));
-
-  // Place all nodes on GPU.
-  for (int i = 0; i < item.graph.node_size(); ++i) {
-    item.graph.mutable_node(i)->set_device("/device:GPU:0");
-  }
-
-  Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
-  GraphDef output;
-  TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
-
-  int found = 0;
-  for (const NodeDef& node : output.node()) {
-    if (node.name() == "relu") {
-      EXPECT_EQ(node.op(), "Identity");
-      ASSERT_EQ(node.input_size(), 1);
-      EXPECT_EQ(node.input(0), "fused_batch_norm");
-      found++;
-    }
-    if (node.name() == "fused_batch_norm") {
-      EXPECT_EQ(node.op(), "_FusedBatchNormEx");
-      ASSERT_EQ(node.input_size(), 6);
-      EXPECT_EQ(node.input(0), "input_cast");
-      EXPECT_EQ(node.input(1), "scale");
-      EXPECT_EQ(node.input(2), "offset");
-      EXPECT_EQ(node.input(3), "mean");
-      EXPECT_EQ(node.input(4), "var");
-      EXPECT_EQ(node.input(5), "side_input_cast");
-
-      auto attr = node.attr();
-      EXPECT_EQ(attr["num_side_inputs"].i(), 1);
-      EXPECT_EQ(attr["activation_mode"].s(), "Relu");
-      found++;
-    }
-  }
-  EXPECT_EQ(found, 2);
-
-  if (GetNumAvailableGPUs() > 0) {
-    auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
-    ASSERT_EQ(tensors_expected.size(), 1);
-    auto tensors = EvaluateNodes(output, item.fetch, item.feed);
-    ASSERT_EQ(tensors.size(), 1);
-    test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
-  }
-#endif  // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+  for (bool is_training : {true, false}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+    if (is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
+                << "[is_training=" << is_training << "] "
+                << "test. It requires CUDNN_VERSION >= 7402.";
+      continue;
+    }
+#endif
+
+#if !defined(GOOGLE_CUDA)
+    if (!is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
+                << "[is_training=" << is_training << "]";
+      continue;
+    }
+#endif
+
+    const int num_channels = 24;
+
+    TensorShape input_shape({2, 8, 8, num_channels});
+    TensorShape channel_shape({num_channels});
+    TensorShape empty_shape({0});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
+                             ops::Placeholder::Shape(input_shape));
+    auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
+    auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
+    auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
+    auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
+    auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
+    auto side_input = Placeholder(s.WithOpName("side_input"), DT_FLOAT);
+    auto side_input_cast =
+        ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
+
+    float epsilon = 0.1f;
+    auto fbn = ops::FusedBatchNormV3(
+        s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
+        ops::FusedBatchNormV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat("NHWC"));
+    auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
+    auto relu = ops::Relu(s.WithOpName("relu"), add);
+    auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
+
+    auto input_t = GenerateRandomTensor<DT_FLOAT>(input_shape);
+    auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                             : channel_shape);
+    auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                            : channel_shape);
+    auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
+
+    GrapplerItem item;
+    item.fetch = {"fetch"};
+    item.feed = {{"input", input_t},   {"scale", scale_t},
+                 {"offset", offset_t}, {"mean", mean_t},
+                 {"var", var_t},       {"side_input", side_input_t}};
+    TF_ASSERT_OK(s.ToGraphDef(&item.graph));
+
+    // Place all nodes on GPU.
+    for (int i = 0; i < item.graph.node_size(); ++i) {
+      item.graph.mutable_node(i)->set_device("/device:GPU:0");
+    }
+
+    Remapper optimizer(RewriterConfig::AGGRESSIVE);  // trust placeholders shape
+    GraphDef output;
+    TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+    int found = 0;
+    for (const NodeDef& node : output.node()) {
+      if (node.name() == "relu") {
+        EXPECT_EQ(node.op(), "Identity");
+        ASSERT_EQ(node.input_size(), 1);
+        EXPECT_EQ(node.input(0), "fused_batch_norm");
+        found++;
+      }
+      if (node.name() == "fused_batch_norm") {
+        EXPECT_EQ(node.op(), "_FusedBatchNormEx");
+        ASSERT_EQ(node.input_size(), 6);
+        EXPECT_EQ(node.input(0), "input_cast");
+        EXPECT_EQ(node.input(1), "scale");
+        EXPECT_EQ(node.input(2), "offset");
+        EXPECT_EQ(node.input(3), "mean");
+        EXPECT_EQ(node.input(4), "var");
+        EXPECT_EQ(node.input(5), "side_input_cast");
+
+        auto attr = node.attr();
+        EXPECT_EQ(attr["num_side_inputs"].i(), 1);
+        EXPECT_EQ(attr["activation_mode"].s(), "Relu");
+        found++;
+      }
+    }
+    EXPECT_EQ(found, 2);
+
+    if (GetNumAvailableGPUs() > 0) {
+      auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+      ASSERT_EQ(tensors_expected.size(), 1);
+      auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+      ASSERT_EQ(tensors.size(), 1);
+      test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
+    }
+  }
 }
 
 TEST_F(RemapperTest, FuseConv2DWithBias) {