Add a remapper pattern for FusedBatchNorm[is_training=False].
tf_cnn_benchmarks --forward_only=true --model=resnet50: fp32: 1010 images/sec -> 1145 images/sec fp16: 1990 images/sec -> 3010 images/sec PiperOrigin-RevId: 253610151
This commit is contained in:
parent
ce59aa0be8
commit
d73778d205
@ -761,8 +761,6 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
|
||||
if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
|
||||
.ok())
|
||||
return false;
|
||||
// TODO(ezhulenev): Add support for is_training=True and custom CUDA kernel.
|
||||
if (!is_training) return false;
|
||||
|
||||
// In training mode we rely on cuDNN for computing FusedBatchNorm with side
|
||||
// inputs and activation, and it has its own limitations. In inference mode
|
||||
@ -796,7 +794,7 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
|
||||
!HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
|
||||
return false;
|
||||
|
||||
// Check that only one node consumes the output of a FusedBatchNorm.
|
||||
// Check that only one node consumes the 0-th output of a FusedBatchNorm.
|
||||
if (HasControlFaninOrFanout(fused_batch_norm) ||
|
||||
!HasAtMostOneFanoutAtPort0(fused_batch_norm) ||
|
||||
IsInPreserveSet(ctx, fused_batch_norm_node_def))
|
||||
@ -1216,11 +1214,14 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx,
|
||||
const NodeDef& activation = graph->node(matched.activation);
|
||||
|
||||
VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
|
||||
<< " side_input="
|
||||
<< " activation=" << activation.name() << " side_input="
|
||||
<< (matched.side_input != kMissingIndex
|
||||
? graph->node(matched.side_input).name()
|
||||
: "<none>")
|
||||
<< " activation=" << activation.name()
|
||||
<< " invalidated="
|
||||
<< (matched.invalidated != kMissingIndex
|
||||
? graph->node(matched.invalidated).name()
|
||||
: "<none>")
|
||||
<< " fused_batch_norm=" << fused_batch_norm.name();
|
||||
|
||||
// Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
|
||||
@ -1512,23 +1513,6 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
|
||||
|
||||
Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) {
|
||||
// Supported graph patterns.
|
||||
// clang-format off
|
||||
FusedBatchNorm fused_batch_norm;
|
||||
FusedBatchNormEx fused_batch_norm_ex;
|
||||
ContractionWithBiasAdd contract_with_bias;
|
||||
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
|
||||
#ifndef INTEL_MKL
|
||||
ContractionWithBatchNorm contract_with_batch_norm;
|
||||
ContractionWithBatchNormAndActivation contract_with_batch_norm_and_activation;
|
||||
ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
|
||||
#endif // !INTEL_MKL
|
||||
#ifdef INTEL_MKL
|
||||
ContractionWithBiasAddAndAdd contract_with_bias_and_add;
|
||||
ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
|
||||
#endif // INTEL_MKL
|
||||
// clang-format on
|
||||
|
||||
GrapplerItem mutable_item = item;
|
||||
Status status;
|
||||
RemapperContext ctx(&mutable_item, &status);
|
||||
@ -1556,6 +1540,9 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
}
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
ContractionWithBiasAddAndAdd contract_with_bias_and_add;
|
||||
ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
|
||||
|
||||
if (!item.optimization_options().is_eager_mode) {
|
||||
// Remap Conv2D+BiasAdd+Add+relu into the _FusedConv2D.
|
||||
if (FindContractionWithBiasAndAddActivation(
|
||||
@ -1578,6 +1565,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
#endif //! INTEL_MKL
|
||||
|
||||
// Remap {Conv2D,MatMul}+BiasAdd into the _Fused{Conv2D,MatMul}
|
||||
ContractionWithBiasAdd contract_with_bias;
|
||||
if (allow_non_differentiable_rewrites &&
|
||||
FindContractionWithBias(ctx, i, &contract_with_bias)) {
|
||||
TF_RETURN_IF_ERROR(AddFusedContractionNode(
|
||||
@ -1586,6 +1574,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
}
|
||||
|
||||
// Remap {Conv2D,MatMul}+BiasAdd+Activation into the _Fused{Conv2D,MatMul}.
|
||||
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
|
||||
if (allow_non_differentiable_rewrites &&
|
||||
FindContractionWithBiasAndActivation(
|
||||
ctx, i, &contract_with_bias_and_activation)) {
|
||||
@ -1603,6 +1592,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
// Remove this once TF-MKL supports _FusedConv2D with these operations.
|
||||
#ifndef INTEL_MKL
|
||||
// Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
|
||||
ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
|
||||
if (allow_non_differentiable_rewrites &&
|
||||
FindConv2DWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -1612,6 +1602,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
}
|
||||
|
||||
// Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
|
||||
ContractionWithBatchNorm contract_with_batch_norm;
|
||||
if (allow_non_differentiable_rewrites &&
|
||||
FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
|
||||
TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
|
||||
@ -1621,6 +1612,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
}
|
||||
|
||||
// Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
|
||||
ContractionWithBatchNormAndActivation
|
||||
contract_with_batch_norm_and_activation;
|
||||
if (allow_non_differentiable_rewrites &&
|
||||
FindConv2DWithBatchNormAndActivation(
|
||||
ctx, i, &contract_with_batch_norm_and_activation)) {
|
||||
@ -1644,6 +1637,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
}
|
||||
|
||||
// Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
|
||||
FusedBatchNormEx fused_batch_norm_ex;
|
||||
if (allow_non_differentiable_rewrites &&
|
||||
FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
|
||||
TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
|
||||
@ -1653,6 +1647,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
|
||||
// During inference, most of the inputs to FusedBatchNorm are constant, and
|
||||
// we can therefore replace the op with a much cheaper set of primitives.
|
||||
FusedBatchNorm fused_batch_norm;
|
||||
if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
|
||||
TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
|
||||
continue;
|
||||
|
@ -112,185 +112,225 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-5);
|
||||
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-3);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(RemapperTest, FuseBatchNormWithRelu) {
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
for (bool is_training : {true, false}) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||
LOG(INFO) << "Skip FuseBatchNormWithRelu test. It requires "
|
||||
"CUDNN_VERSION >= 7402.";
|
||||
#else
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
|
||||
auto channels_shape = ops::Placeholder::Shape({24});
|
||||
|
||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
||||
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
||||
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
|
||||
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
|
||||
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
|
||||
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
|
||||
|
||||
float epsilon = 0.1f;
|
||||
auto fbn = ops::FusedBatchNormV3(
|
||||
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
||||
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
|
||||
"NHWC"));
|
||||
auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
|
||||
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
||||
|
||||
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
||||
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
|
||||
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
|
||||
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
||||
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
item.feed = {{"input", input_t},
|
||||
{"scale", scale_t},
|
||||
{"offset", offset_t},
|
||||
{"mean", mean_t},
|
||||
{"var", var_t}};
|
||||
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
// Place all nodes on GPU.
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||
item.graph.mutable_node(i)->set_device("/device:GPU:0");
|
||||
}
|
||||
|
||||
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
int found = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "relu") {
|
||||
EXPECT_EQ(node.op(), "Identity");
|
||||
ASSERT_EQ(node.input_size(), 1);
|
||||
EXPECT_EQ(node.input(0), "fused_batch_norm");
|
||||
found++;
|
||||
if (is_training) {
|
||||
LOG(INFO) << "Skip FuseBatchNormWithRelu"
|
||||
<< "[is_training=" << is_training << "] "
|
||||
<< "test. It requires CUDNN_VERSION >= 7402.";
|
||||
continue;
|
||||
}
|
||||
if (node.name() == "fused_batch_norm") {
|
||||
EXPECT_EQ(node.op(), "_FusedBatchNormEx");
|
||||
ASSERT_EQ(node.input_size(), 5);
|
||||
EXPECT_EQ(node.input(0), "input_cast");
|
||||
EXPECT_EQ(node.input(1), "scale");
|
||||
EXPECT_EQ(node.input(2), "offset");
|
||||
EXPECT_EQ(node.input(3), "mean");
|
||||
EXPECT_EQ(node.input(4), "var");
|
||||
#endif
|
||||
|
||||
auto attr = node.attr();
|
||||
EXPECT_EQ(attr["num_side_inputs"].i(), 0);
|
||||
EXPECT_EQ(attr["activation_mode"].s(), "Relu");
|
||||
found++;
|
||||
#if !defined(GOOGLE_CUDA)
|
||||
if (!is_training) {
|
||||
LOG(INFO) << "Skip FuseBatchNormWithRelu"
|
||||
<< "[is_training=" << is_training << "]";
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
|
||||
const int num_channels = 24;
|
||||
|
||||
TensorShape channel_shape({num_channels});
|
||||
TensorShape empty_shape({0});
|
||||
|
||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
|
||||
ops::Placeholder::Shape({2, 8, 8, num_channels}));
|
||||
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
||||
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
|
||||
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
|
||||
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
|
||||
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
|
||||
|
||||
float epsilon = 0.1f;
|
||||
auto fbn = ops::FusedBatchNormV3(
|
||||
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
||||
ops::FusedBatchNormV3::IsTraining(is_training)
|
||||
.Epsilon(epsilon)
|
||||
.DataFormat("NHWC"));
|
||||
auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
|
||||
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
||||
|
||||
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
|
||||
auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||
auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||
auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||
: channel_shape);
|
||||
auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||
: channel_shape);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
item.feed = {{"input", input_t},
|
||||
{"scale", scale_t},
|
||||
{"offset", offset_t},
|
||||
{"mean", mean_t},
|
||||
{"var", var_t}};
|
||||
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
// Place all nodes on GPU.
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||
item.graph.mutable_node(i)->set_device("/device:GPU:0");
|
||||
}
|
||||
|
||||
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
int found = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "relu") {
|
||||
EXPECT_EQ(node.op(), "Identity");
|
||||
ASSERT_EQ(node.input_size(), 1);
|
||||
EXPECT_EQ(node.input(0), "fused_batch_norm");
|
||||
found++;
|
||||
}
|
||||
if (node.name() == "fused_batch_norm") {
|
||||
EXPECT_EQ(node.op(), "_FusedBatchNormEx");
|
||||
ASSERT_EQ(node.input_size(), 5);
|
||||
EXPECT_EQ(node.input(0), "input_cast");
|
||||
EXPECT_EQ(node.input(1), "scale");
|
||||
EXPECT_EQ(node.input(2), "offset");
|
||||
EXPECT_EQ(node.input(3), "mean");
|
||||
EXPECT_EQ(node.input(4), "var");
|
||||
|
||||
auto attr = node.attr();
|
||||
EXPECT_EQ(attr["num_side_inputs"].i(), 0);
|
||||
EXPECT_EQ(attr["activation_mode"].s(), "Relu");
|
||||
found++;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(found, 2);
|
||||
|
||||
if (GetNumAvailableGPUs() > 0) {
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(found, 2);
|
||||
|
||||
if (GetNumAvailableGPUs() > 0) {
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
||||
}
|
||||
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||
}
|
||||
|
||||
TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
for (bool is_training : {true, false}) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu test. It requires "
|
||||
"CUDNN_VERSION >= 7402.";
|
||||
#else
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
|
||||
auto channels_shape = ops::Placeholder::Shape({24});
|
||||
|
||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
||||
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
||||
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
|
||||
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
|
||||
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
|
||||
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
|
||||
auto side_input =
|
||||
Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape);
|
||||
auto side_input_cast =
|
||||
ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
|
||||
|
||||
float epsilon = 0.1f;
|
||||
auto fbn = ops::FusedBatchNormV3(
|
||||
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
||||
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
|
||||
"NHWC"));
|
||||
auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
|
||||
auto relu = ops::Relu(s.WithOpName("relu"), add);
|
||||
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
||||
|
||||
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
||||
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
|
||||
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
|
||||
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
||||
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
||||
auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
item.feed = {{"input", input_t}, {"scale", scale_t},
|
||||
{"offset", offset_t}, {"mean", mean_t},
|
||||
{"var", var_t}, {"side_input", side_input_t}};
|
||||
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
// Place all nodes on GPU.
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||
item.graph.mutable_node(i)->set_device("/device:GPU:0");
|
||||
}
|
||||
|
||||
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
int found = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "relu") {
|
||||
EXPECT_EQ(node.op(), "Identity");
|
||||
ASSERT_EQ(node.input_size(), 1);
|
||||
EXPECT_EQ(node.input(0), "fused_batch_norm");
|
||||
found++;
|
||||
if (is_training) {
|
||||
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
|
||||
<< "[is_training=" << is_training << "] "
|
||||
<< "test. It requires CUDNN_VERSION >= 7402.";
|
||||
continue;
|
||||
}
|
||||
if (node.name() == "fused_batch_norm") {
|
||||
EXPECT_EQ(node.op(), "_FusedBatchNormEx");
|
||||
ASSERT_EQ(node.input_size(), 6);
|
||||
EXPECT_EQ(node.input(0), "input_cast");
|
||||
EXPECT_EQ(node.input(1), "scale");
|
||||
EXPECT_EQ(node.input(2), "offset");
|
||||
EXPECT_EQ(node.input(3), "mean");
|
||||
EXPECT_EQ(node.input(4), "var");
|
||||
EXPECT_EQ(node.input(5), "side_input_cast");
|
||||
#endif
|
||||
|
||||
auto attr = node.attr();
|
||||
EXPECT_EQ(attr["num_side_inputs"].i(), 1);
|
||||
EXPECT_EQ(attr["activation_mode"].s(), "Relu");
|
||||
found++;
|
||||
#if !defined(GOOGLE_CUDA)
|
||||
if (!is_training) {
|
||||
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
|
||||
<< "[is_training=" << is_training << "]";
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
|
||||
const int num_channels = 24;
|
||||
|
||||
TensorShape input_shape({2, 8, 8, num_channels});
|
||||
TensorShape channel_shape({num_channels});
|
||||
TensorShape empty_shape({0});
|
||||
|
||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
|
||||
ops::Placeholder::Shape(input_shape));
|
||||
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
||||
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
|
||||
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
|
||||
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
|
||||
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
|
||||
auto side_input = Placeholder(s.WithOpName("side_input"), DT_FLOAT);
|
||||
auto side_input_cast =
|
||||
ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
|
||||
|
||||
float epsilon = 0.1f;
|
||||
auto fbn = ops::FusedBatchNormV3(
|
||||
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
||||
ops::FusedBatchNormV3::IsTraining(is_training)
|
||||
.Epsilon(epsilon)
|
||||
.DataFormat("NHWC"));
|
||||
auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
|
||||
auto relu = ops::Relu(s.WithOpName("relu"), add);
|
||||
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
||||
|
||||
auto input_t = GenerateRandomTensor<DT_FLOAT>(input_shape);
|
||||
auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||
auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||
auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||
: channel_shape);
|
||||
auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||
: channel_shape);
|
||||
auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"fetch"};
|
||||
item.feed = {{"input", input_t}, {"scale", scale_t},
|
||||
{"offset", offset_t}, {"mean", mean_t},
|
||||
{"var", var_t}, {"side_input", side_input_t}};
|
||||
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
// Place all nodes on GPU.
|
||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||
item.graph.mutable_node(i)->set_device("/device:GPU:0");
|
||||
}
|
||||
|
||||
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
|
||||
GraphDef output;
|
||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
|
||||
int found = 0;
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "relu") {
|
||||
EXPECT_EQ(node.op(), "Identity");
|
||||
ASSERT_EQ(node.input_size(), 1);
|
||||
EXPECT_EQ(node.input(0), "fused_batch_norm");
|
||||
found++;
|
||||
}
|
||||
if (node.name() == "fused_batch_norm") {
|
||||
EXPECT_EQ(node.op(), "_FusedBatchNormEx");
|
||||
ASSERT_EQ(node.input_size(), 6);
|
||||
EXPECT_EQ(node.input(0), "input_cast");
|
||||
EXPECT_EQ(node.input(1), "scale");
|
||||
EXPECT_EQ(node.input(2), "offset");
|
||||
EXPECT_EQ(node.input(3), "mean");
|
||||
EXPECT_EQ(node.input(4), "var");
|
||||
EXPECT_EQ(node.input(5), "side_input_cast");
|
||||
|
||||
auto attr = node.attr();
|
||||
EXPECT_EQ(attr["num_side_inputs"].i(), 1);
|
||||
EXPECT_EQ(attr["activation_mode"].s(), "Relu");
|
||||
found++;
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(found, 2);
|
||||
|
||||
if (GetNumAvailableGPUs() > 0) {
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
||||
}
|
||||
}
|
||||
EXPECT_EQ(found, 2);
|
||||
|
||||
if (GetNumAvailableGPUs() > 0) {
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
||||
}
|
||||
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||
}
|
||||
|
||||
TEST_F(RemapperTest, FuseConv2DWithBias) {
|
||||
|
Loading…
Reference in New Issue
Block a user