Add a remapper pattern for FusedBatchNorm[is_training=False].
tf_cnn_benchmarks --forward_only=true --model=resnet50: fp32: 1010 images/sec -> 1145 images/sec fp16: 1990 images/sec -> 3010 images/sec PiperOrigin-RevId: 253610151
This commit is contained in:
parent
ce59aa0be8
commit
d73778d205
@ -761,8 +761,6 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
|
|||||||
if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
|
if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
|
||||||
.ok())
|
.ok())
|
||||||
return false;
|
return false;
|
||||||
// TODO(ezhulenev): Add support for is_training=True and custom CUDA kernel.
|
|
||||||
if (!is_training) return false;
|
|
||||||
|
|
||||||
// In training mode we rely on cuDNN for computing FusedBatchNorm with side
|
// In training mode we rely on cuDNN for computing FusedBatchNorm with side
|
||||||
// inputs and activation, and it has its own limitations. In inference mode
|
// inputs and activation, and it has its own limitations. In inference mode
|
||||||
@ -796,7 +794,7 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
|
|||||||
!HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
|
!HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
// Check that only one node consumes the output of a FusedBatchNorm.
|
// Check that only one node consumes the 0-th output of a FusedBatchNorm.
|
||||||
if (HasControlFaninOrFanout(fused_batch_norm) ||
|
if (HasControlFaninOrFanout(fused_batch_norm) ||
|
||||||
!HasAtMostOneFanoutAtPort0(fused_batch_norm) ||
|
!HasAtMostOneFanoutAtPort0(fused_batch_norm) ||
|
||||||
IsInPreserveSet(ctx, fused_batch_norm_node_def))
|
IsInPreserveSet(ctx, fused_batch_norm_node_def))
|
||||||
@ -1216,11 +1214,14 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx,
|
|||||||
const NodeDef& activation = graph->node(matched.activation);
|
const NodeDef& activation = graph->node(matched.activation);
|
||||||
|
|
||||||
VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
|
VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
|
||||||
<< " side_input="
|
<< " activation=" << activation.name() << " side_input="
|
||||||
<< (matched.side_input != kMissingIndex
|
<< (matched.side_input != kMissingIndex
|
||||||
? graph->node(matched.side_input).name()
|
? graph->node(matched.side_input).name()
|
||||||
: "<none>")
|
: "<none>")
|
||||||
<< " activation=" << activation.name()
|
<< " invalidated="
|
||||||
|
<< (matched.invalidated != kMissingIndex
|
||||||
|
? graph->node(matched.invalidated).name()
|
||||||
|
: "<none>")
|
||||||
<< " fused_batch_norm=" << fused_batch_norm.name();
|
<< " fused_batch_norm=" << fused_batch_norm.name();
|
||||||
|
|
||||||
// Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
|
// Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
|
||||||
@ -1512,23 +1513,6 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
|
|||||||
|
|
||||||
Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||||
GraphDef* optimized_graph) {
|
GraphDef* optimized_graph) {
|
||||||
// Supported graph patterns.
|
|
||||||
// clang-format off
|
|
||||||
FusedBatchNorm fused_batch_norm;
|
|
||||||
FusedBatchNormEx fused_batch_norm_ex;
|
|
||||||
ContractionWithBiasAdd contract_with_bias;
|
|
||||||
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
|
|
||||||
#ifndef INTEL_MKL
|
|
||||||
ContractionWithBatchNorm contract_with_batch_norm;
|
|
||||||
ContractionWithBatchNormAndActivation contract_with_batch_norm_and_activation;
|
|
||||||
ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
|
|
||||||
#endif // !INTEL_MKL
|
|
||||||
#ifdef INTEL_MKL
|
|
||||||
ContractionWithBiasAddAndAdd contract_with_bias_and_add;
|
|
||||||
ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
|
|
||||||
#endif // INTEL_MKL
|
|
||||||
// clang-format on
|
|
||||||
|
|
||||||
GrapplerItem mutable_item = item;
|
GrapplerItem mutable_item = item;
|
||||||
Status status;
|
Status status;
|
||||||
RemapperContext ctx(&mutable_item, &status);
|
RemapperContext ctx(&mutable_item, &status);
|
||||||
@ -1556,6 +1540,9 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#ifdef INTEL_MKL
|
#ifdef INTEL_MKL
|
||||||
|
ContractionWithBiasAddAndAdd contract_with_bias_and_add;
|
||||||
|
ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
|
||||||
|
|
||||||
if (!item.optimization_options().is_eager_mode) {
|
if (!item.optimization_options().is_eager_mode) {
|
||||||
// Remap Conv2D+BiasAdd+Add+relu into the _FusedConv2D.
|
// Remap Conv2D+BiasAdd+Add+relu into the _FusedConv2D.
|
||||||
if (FindContractionWithBiasAndAddActivation(
|
if (FindContractionWithBiasAndAddActivation(
|
||||||
@ -1578,6 +1565,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
#endif //! INTEL_MKL
|
#endif //! INTEL_MKL
|
||||||
|
|
||||||
// Remap {Conv2D,MatMul}+BiasAdd into the _Fused{Conv2D,MatMul}
|
// Remap {Conv2D,MatMul}+BiasAdd into the _Fused{Conv2D,MatMul}
|
||||||
|
ContractionWithBiasAdd contract_with_bias;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindContractionWithBias(ctx, i, &contract_with_bias)) {
|
FindContractionWithBias(ctx, i, &contract_with_bias)) {
|
||||||
TF_RETURN_IF_ERROR(AddFusedContractionNode(
|
TF_RETURN_IF_ERROR(AddFusedContractionNode(
|
||||||
@ -1586,6 +1574,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remap {Conv2D,MatMul}+BiasAdd+Activation into the _Fused{Conv2D,MatMul}.
|
// Remap {Conv2D,MatMul}+BiasAdd+Activation into the _Fused{Conv2D,MatMul}.
|
||||||
|
ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindContractionWithBiasAndActivation(
|
FindContractionWithBiasAndActivation(
|
||||||
ctx, i, &contract_with_bias_and_activation)) {
|
ctx, i, &contract_with_bias_and_activation)) {
|
||||||
@ -1603,6 +1592,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
// Remove this once TF-MKL supports _FusedConv2D with these operations.
|
// Remove this once TF-MKL supports _FusedConv2D with these operations.
|
||||||
#ifndef INTEL_MKL
|
#ifndef INTEL_MKL
|
||||||
// Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
|
// Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
|
||||||
|
ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindConv2DWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
|
FindConv2DWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
@ -1612,6 +1602,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
|
// Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
|
||||||
|
ContractionWithBatchNorm contract_with_batch_norm;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
|
FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
|
||||||
TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
|
TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
|
||||||
@ -1621,6 +1612,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
|
// Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
|
||||||
|
ContractionWithBatchNormAndActivation
|
||||||
|
contract_with_batch_norm_and_activation;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindConv2DWithBatchNormAndActivation(
|
FindConv2DWithBatchNormAndActivation(
|
||||||
ctx, i, &contract_with_batch_norm_and_activation)) {
|
ctx, i, &contract_with_batch_norm_and_activation)) {
|
||||||
@ -1644,6 +1637,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
|
// Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
|
||||||
|
FusedBatchNormEx fused_batch_norm_ex;
|
||||||
if (allow_non_differentiable_rewrites &&
|
if (allow_non_differentiable_rewrites &&
|
||||||
FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
|
FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
|
||||||
TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
|
TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
|
||||||
@ -1653,6 +1647,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
|
|||||||
|
|
||||||
// During inference, most of the inputs to FusedBatchNorm are constant, and
|
// During inference, most of the inputs to FusedBatchNorm are constant, and
|
||||||
// we can therefore replace the op with a much cheaper set of primitives.
|
// we can therefore replace the op with a much cheaper set of primitives.
|
||||||
|
FusedBatchNorm fused_batch_norm;
|
||||||
if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
|
if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
|
||||||
TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
|
TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
|
||||||
continue;
|
continue;
|
||||||
|
@ -112,185 +112,225 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
|
|||||||
ASSERT_EQ(tensors_expected.size(), 1);
|
ASSERT_EQ(tensors_expected.size(), 1);
|
||||||
auto tensors = EvaluateNodes(output, item.fetch);
|
auto tensors = EvaluateNodes(output, item.fetch);
|
||||||
ASSERT_EQ(tensors.size(), 1);
|
ASSERT_EQ(tensors.size(), 1);
|
||||||
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-5);
|
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RemapperTest, FuseBatchNormWithRelu) {
|
TEST_F(RemapperTest, FuseBatchNormWithRelu) {
|
||||||
using ::tensorflow::ops::Placeholder;
|
using ::tensorflow::ops::Placeholder;
|
||||||
|
|
||||||
|
for (bool is_training : {true, false}) {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
|
||||||
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||||
LOG(INFO) << "Skip FuseBatchNormWithRelu test. It requires "
|
if (is_training) {
|
||||||
"CUDNN_VERSION >= 7402.";
|
LOG(INFO) << "Skip FuseBatchNormWithRelu"
|
||||||
#else
|
<< "[is_training=" << is_training << "] "
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
<< "test. It requires CUDNN_VERSION >= 7402.";
|
||||||
|
continue;
|
||||||
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
|
|
||||||
auto channels_shape = ops::Placeholder::Shape({24});
|
|
||||||
|
|
||||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
|
||||||
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
|
||||||
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
|
|
||||||
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
|
|
||||||
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
|
|
||||||
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
|
|
||||||
|
|
||||||
float epsilon = 0.1f;
|
|
||||||
auto fbn = ops::FusedBatchNormV3(
|
|
||||||
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
|
||||||
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
|
|
||||||
"NHWC"));
|
|
||||||
auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
|
|
||||||
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
|
||||||
|
|
||||||
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
|
||||||
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
|
|
||||||
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
|
|
||||||
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
|
||||||
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
|
||||||
|
|
||||||
GrapplerItem item;
|
|
||||||
item.fetch = {"fetch"};
|
|
||||||
item.feed = {{"input", input_t},
|
|
||||||
{"scale", scale_t},
|
|
||||||
{"offset", offset_t},
|
|
||||||
{"mean", mean_t},
|
|
||||||
{"var", var_t}};
|
|
||||||
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
|
|
||||||
|
|
||||||
// Place all nodes on GPU.
|
|
||||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
|
||||||
item.graph.mutable_node(i)->set_device("/device:GPU:0");
|
|
||||||
}
|
|
||||||
|
|
||||||
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
|
|
||||||
GraphDef output;
|
|
||||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
|
||||||
|
|
||||||
int found = 0;
|
|
||||||
for (const NodeDef& node : output.node()) {
|
|
||||||
if (node.name() == "relu") {
|
|
||||||
EXPECT_EQ(node.op(), "Identity");
|
|
||||||
ASSERT_EQ(node.input_size(), 1);
|
|
||||||
EXPECT_EQ(node.input(0), "fused_batch_norm");
|
|
||||||
found++;
|
|
||||||
}
|
}
|
||||||
if (node.name() == "fused_batch_norm") {
|
#endif
|
||||||
EXPECT_EQ(node.op(), "_FusedBatchNormEx");
|
|
||||||
ASSERT_EQ(node.input_size(), 5);
|
|
||||||
EXPECT_EQ(node.input(0), "input_cast");
|
|
||||||
EXPECT_EQ(node.input(1), "scale");
|
|
||||||
EXPECT_EQ(node.input(2), "offset");
|
|
||||||
EXPECT_EQ(node.input(3), "mean");
|
|
||||||
EXPECT_EQ(node.input(4), "var");
|
|
||||||
|
|
||||||
auto attr = node.attr();
|
#if !defined(GOOGLE_CUDA)
|
||||||
EXPECT_EQ(attr["num_side_inputs"].i(), 0);
|
if (!is_training) {
|
||||||
EXPECT_EQ(attr["activation_mode"].s(), "Relu");
|
LOG(INFO) << "Skip FuseBatchNormWithRelu"
|
||||||
found++;
|
<< "[is_training=" << is_training << "]";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
const int num_channels = 24;
|
||||||
|
|
||||||
|
TensorShape channel_shape({num_channels});
|
||||||
|
TensorShape empty_shape({0});
|
||||||
|
|
||||||
|
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
|
||||||
|
ops::Placeholder::Shape({2, 8, 8, num_channels}));
|
||||||
|
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
||||||
|
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
|
||||||
|
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
|
||||||
|
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
|
||||||
|
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
|
||||||
|
|
||||||
|
float epsilon = 0.1f;
|
||||||
|
auto fbn = ops::FusedBatchNormV3(
|
||||||
|
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
||||||
|
ops::FusedBatchNormV3::IsTraining(is_training)
|
||||||
|
.Epsilon(epsilon)
|
||||||
|
.DataFormat("NHWC"));
|
||||||
|
auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
|
||||||
|
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
||||||
|
|
||||||
|
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
|
||||||
|
auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||||
|
auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||||
|
auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||||
|
: channel_shape);
|
||||||
|
auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||||
|
: channel_shape);
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
item.fetch = {"fetch"};
|
||||||
|
item.feed = {{"input", input_t},
|
||||||
|
{"scale", scale_t},
|
||||||
|
{"offset", offset_t},
|
||||||
|
{"mean", mean_t},
|
||||||
|
{"var", var_t}};
|
||||||
|
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
|
||||||
|
|
||||||
|
// Place all nodes on GPU.
|
||||||
|
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||||
|
item.graph.mutable_node(i)->set_device("/device:GPU:0");
|
||||||
|
}
|
||||||
|
|
||||||
|
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
|
||||||
|
GraphDef output;
|
||||||
|
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||||
|
|
||||||
|
int found = 0;
|
||||||
|
for (const NodeDef& node : output.node()) {
|
||||||
|
if (node.name() == "relu") {
|
||||||
|
EXPECT_EQ(node.op(), "Identity");
|
||||||
|
ASSERT_EQ(node.input_size(), 1);
|
||||||
|
EXPECT_EQ(node.input(0), "fused_batch_norm");
|
||||||
|
found++;
|
||||||
|
}
|
||||||
|
if (node.name() == "fused_batch_norm") {
|
||||||
|
EXPECT_EQ(node.op(), "_FusedBatchNormEx");
|
||||||
|
ASSERT_EQ(node.input_size(), 5);
|
||||||
|
EXPECT_EQ(node.input(0), "input_cast");
|
||||||
|
EXPECT_EQ(node.input(1), "scale");
|
||||||
|
EXPECT_EQ(node.input(2), "offset");
|
||||||
|
EXPECT_EQ(node.input(3), "mean");
|
||||||
|
EXPECT_EQ(node.input(4), "var");
|
||||||
|
|
||||||
|
auto attr = node.attr();
|
||||||
|
EXPECT_EQ(attr["num_side_inputs"].i(), 0);
|
||||||
|
EXPECT_EQ(attr["activation_mode"].s(), "Relu");
|
||||||
|
found++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_EQ(found, 2);
|
||||||
|
|
||||||
|
if (GetNumAvailableGPUs() > 0) {
|
||||||
|
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||||
|
ASSERT_EQ(tensors_expected.size(), 1);
|
||||||
|
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||||
|
ASSERT_EQ(tensors.size(), 1);
|
||||||
|
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EXPECT_EQ(found, 2);
|
|
||||||
|
|
||||||
if (GetNumAvailableGPUs() > 0) {
|
|
||||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
|
||||||
ASSERT_EQ(tensors_expected.size(), 1);
|
|
||||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
|
||||||
ASSERT_EQ(tensors.size(), 1);
|
|
||||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
|
||||||
}
|
|
||||||
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
|
TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
|
||||||
using ::tensorflow::ops::Placeholder;
|
using ::tensorflow::ops::Placeholder;
|
||||||
|
|
||||||
|
for (bool is_training : {true, false}) {
|
||||||
|
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||||
|
|
||||||
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
||||||
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu test. It requires "
|
if (is_training) {
|
||||||
"CUDNN_VERSION >= 7402.";
|
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
|
||||||
#else
|
<< "[is_training=" << is_training << "] "
|
||||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
<< "test. It requires CUDNN_VERSION >= 7402.";
|
||||||
|
continue;
|
||||||
auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
|
|
||||||
auto channels_shape = ops::Placeholder::Shape({24});
|
|
||||||
|
|
||||||
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
|
|
||||||
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
|
||||||
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
|
|
||||||
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
|
|
||||||
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
|
|
||||||
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
|
|
||||||
auto side_input =
|
|
||||||
Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape);
|
|
||||||
auto side_input_cast =
|
|
||||||
ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
|
|
||||||
|
|
||||||
float epsilon = 0.1f;
|
|
||||||
auto fbn = ops::FusedBatchNormV3(
|
|
||||||
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
|
||||||
ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
|
|
||||||
"NHWC"));
|
|
||||||
auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
|
|
||||||
auto relu = ops::Relu(s.WithOpName("relu"), add);
|
|
||||||
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
|
||||||
|
|
||||||
auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
|
||||||
auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
|
|
||||||
auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
|
|
||||||
auto mean_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
|
||||||
auto var_t = GenerateRandomTensor<DT_FLOAT>({0}); // empty for training
|
|
||||||
auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
|
|
||||||
|
|
||||||
GrapplerItem item;
|
|
||||||
item.fetch = {"fetch"};
|
|
||||||
item.feed = {{"input", input_t}, {"scale", scale_t},
|
|
||||||
{"offset", offset_t}, {"mean", mean_t},
|
|
||||||
{"var", var_t}, {"side_input", side_input_t}};
|
|
||||||
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
|
|
||||||
|
|
||||||
// Place all nodes on GPU.
|
|
||||||
for (int i = 0; i < item.graph.node_size(); ++i) {
|
|
||||||
item.graph.mutable_node(i)->set_device("/device:GPU:0");
|
|
||||||
}
|
|
||||||
|
|
||||||
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
|
|
||||||
GraphDef output;
|
|
||||||
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
|
||||||
|
|
||||||
int found = 0;
|
|
||||||
for (const NodeDef& node : output.node()) {
|
|
||||||
if (node.name() == "relu") {
|
|
||||||
EXPECT_EQ(node.op(), "Identity");
|
|
||||||
ASSERT_EQ(node.input_size(), 1);
|
|
||||||
EXPECT_EQ(node.input(0), "fused_batch_norm");
|
|
||||||
found++;
|
|
||||||
}
|
}
|
||||||
if (node.name() == "fused_batch_norm") {
|
#endif
|
||||||
EXPECT_EQ(node.op(), "_FusedBatchNormEx");
|
|
||||||
ASSERT_EQ(node.input_size(), 6);
|
|
||||||
EXPECT_EQ(node.input(0), "input_cast");
|
|
||||||
EXPECT_EQ(node.input(1), "scale");
|
|
||||||
EXPECT_EQ(node.input(2), "offset");
|
|
||||||
EXPECT_EQ(node.input(3), "mean");
|
|
||||||
EXPECT_EQ(node.input(4), "var");
|
|
||||||
EXPECT_EQ(node.input(5), "side_input_cast");
|
|
||||||
|
|
||||||
auto attr = node.attr();
|
#if !defined(GOOGLE_CUDA)
|
||||||
EXPECT_EQ(attr["num_side_inputs"].i(), 1);
|
if (!is_training) {
|
||||||
EXPECT_EQ(attr["activation_mode"].s(), "Relu");
|
LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
|
||||||
found++;
|
<< "[is_training=" << is_training << "]";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
const int num_channels = 24;
|
||||||
|
|
||||||
|
TensorShape input_shape({2, 8, 8, num_channels});
|
||||||
|
TensorShape channel_shape({num_channels});
|
||||||
|
TensorShape empty_shape({0});
|
||||||
|
|
||||||
|
auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
|
||||||
|
ops::Placeholder::Shape(input_shape));
|
||||||
|
auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
|
||||||
|
auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
|
||||||
|
auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
|
||||||
|
auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
|
||||||
|
auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
|
||||||
|
auto side_input = Placeholder(s.WithOpName("side_input"), DT_FLOAT);
|
||||||
|
auto side_input_cast =
|
||||||
|
ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
|
||||||
|
|
||||||
|
float epsilon = 0.1f;
|
||||||
|
auto fbn = ops::FusedBatchNormV3(
|
||||||
|
s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
|
||||||
|
ops::FusedBatchNormV3::IsTraining(is_training)
|
||||||
|
.Epsilon(epsilon)
|
||||||
|
.DataFormat("NHWC"));
|
||||||
|
auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
|
||||||
|
auto relu = ops::Relu(s.WithOpName("relu"), add);
|
||||||
|
auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
|
||||||
|
|
||||||
|
auto input_t = GenerateRandomTensor<DT_FLOAT>(input_shape);
|
||||||
|
auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||||
|
auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
|
||||||
|
auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||||
|
: channel_shape);
|
||||||
|
auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
|
||||||
|
: channel_shape);
|
||||||
|
auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
|
||||||
|
|
||||||
|
GrapplerItem item;
|
||||||
|
item.fetch = {"fetch"};
|
||||||
|
item.feed = {{"input", input_t}, {"scale", scale_t},
|
||||||
|
{"offset", offset_t}, {"mean", mean_t},
|
||||||
|
{"var", var_t}, {"side_input", side_input_t}};
|
||||||
|
TF_ASSERT_OK(s.ToGraphDef(&item.graph));
|
||||||
|
|
||||||
|
// Place all nodes on GPU.
|
||||||
|
for (int i = 0; i < item.graph.node_size(); ++i) {
|
||||||
|
item.graph.mutable_node(i)->set_device("/device:GPU:0");
|
||||||
|
}
|
||||||
|
|
||||||
|
Remapper optimizer(RewriterConfig::AGGRESSIVE); // trust placeholders shape
|
||||||
|
GraphDef output;
|
||||||
|
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||||
|
|
||||||
|
int found = 0;
|
||||||
|
for (const NodeDef& node : output.node()) {
|
||||||
|
if (node.name() == "relu") {
|
||||||
|
EXPECT_EQ(node.op(), "Identity");
|
||||||
|
ASSERT_EQ(node.input_size(), 1);
|
||||||
|
EXPECT_EQ(node.input(0), "fused_batch_norm");
|
||||||
|
found++;
|
||||||
|
}
|
||||||
|
if (node.name() == "fused_batch_norm") {
|
||||||
|
EXPECT_EQ(node.op(), "_FusedBatchNormEx");
|
||||||
|
ASSERT_EQ(node.input_size(), 6);
|
||||||
|
EXPECT_EQ(node.input(0), "input_cast");
|
||||||
|
EXPECT_EQ(node.input(1), "scale");
|
||||||
|
EXPECT_EQ(node.input(2), "offset");
|
||||||
|
EXPECT_EQ(node.input(3), "mean");
|
||||||
|
EXPECT_EQ(node.input(4), "var");
|
||||||
|
EXPECT_EQ(node.input(5), "side_input_cast");
|
||||||
|
|
||||||
|
auto attr = node.attr();
|
||||||
|
EXPECT_EQ(attr["num_side_inputs"].i(), 1);
|
||||||
|
EXPECT_EQ(attr["activation_mode"].s(), "Relu");
|
||||||
|
found++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EXPECT_EQ(found, 2);
|
||||||
|
|
||||||
|
if (GetNumAvailableGPUs() > 0) {
|
||||||
|
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
||||||
|
ASSERT_EQ(tensors_expected.size(), 1);
|
||||||
|
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
||||||
|
ASSERT_EQ(tensors.size(), 1);
|
||||||
|
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
EXPECT_EQ(found, 2);
|
|
||||||
|
|
||||||
if (GetNumAvailableGPUs() > 0) {
|
|
||||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
|
|
||||||
ASSERT_EQ(tensors_expected.size(), 1);
|
|
||||||
auto tensors = EvaluateNodes(output, item.fetch, item.feed);
|
|
||||||
ASSERT_EQ(tensors.size(), 1);
|
|
||||||
test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
|
|
||||||
}
|
|
||||||
#endif // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RemapperTest, FuseConv2DWithBias) {
|
TEST_F(RemapperTest, FuseConv2DWithBias) {
|
||||||
|
Loading…
Reference in New Issue
Block a user