Add a remapper pattern for FusedBatchNorm[is_training=False].

tf_cnn_benchmarks --forward_only=true --model=resnet50:
  fp32: 1010 images/sec -> 1145 images/sec
  fp16: 1990 images/sec -> 3010 images/sec
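
For context: with is_training=False the mean and var inputs are the constant
moving statistics, so FusedBatchNorm reduces to a per-channel affine transform
whose coefficients can be precomputed:

  y = scale * (x - mean) / sqrt(var + epsilon) + offset
    = a * x + b,  where  a = scale / sqrt(var + epsilon),  b = offset - a * mean

This is what makes the op cheap to fuse with a surrounding side-input add and
activation (the _FusedBatchNormEx pattern below), or to decompose into plain
multiply/add primitives.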

PiperOrigin-RevId: 253610151
Author: Eugene Zhulenev (2019-06-17 10:36:15 -07:00), committed by TensorFlower Gardener
Parent: ce59aa0be8
Commit: d73778d205
2 changed files with 215 additions and 180 deletions

tensorflow/core/grappler/optimizers/remapper.cc

@@ -761,8 +761,6 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
   if (!GetNodeAttr(*fused_batch_norm_node_def, kIsTraining, &is_training)
            .ok())
     return false;
-  // TODO(ezhulenev): Add support for is_training=True and custom CUDA kernel.
-  if (!is_training) return false;
   // In training mode we rely on cuDNN for computing FusedBatchNorm with side
   // inputs and activation, and it has its own limitations. In inference mode
@@ -796,7 +794,7 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
       !HasDataType(fused_batch_norm_node_def, DT_FLOAT, "U"))
     return false;
-  // Check that only one node consumes the output of a FusedBatchNorm.
+  // Check that only one node consumes the 0-th output of a FusedBatchNorm.
   if (HasControlFaninOrFanout(fused_batch_norm) ||
       !HasAtMostOneFanoutAtPort0(fused_batch_norm) ||
       IsInPreserveSet(ctx, fused_batch_norm_node_def))
@@ -1216,11 +1214,14 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx,
   const NodeDef& activation = graph->node(matched.activation);
   VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:"
-          << " side_input="
+          << " activation=" << activation.name() << " side_input="
           << (matched.side_input != kMissingIndex
                   ? graph->node(matched.side_input).name()
                   : "<none>")
-          << " activation=" << activation.name()
+          << " invalidated="
+          << (matched.invalidated != kMissingIndex
+                  ? graph->node(matched.invalidated).name()
+                  : "<none>")
           << " fused_batch_norm=" << fused_batch_norm.name();
 
   // Replace FusedBatchNorm with _FusedBatchNormEx + <SideInput> + <Activation>.
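
A minimal scalar sketch (illustrative only, not the TensorFlow kernel API) of
what the fused op produced here computes per element in inference mode, with a
Relu activation and an optional side input:

#include <algorithm>
#include <cmath>

// Batch norm, then the <SideInput> add, then Relu, as one expression.
// mean/var are the constant moving statistics; pass side_input = 0.0f
// for the pattern without a side input.
float FusedBatchNormExScalar(float x, float scale, float offset, float mean,
                             float var, float epsilon, float side_input) {
  float y = scale * (x - mean) / std::sqrt(var + epsilon) + offset;
  return std::max(y + side_input, 0.0f);
}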
@@ -1512,23 +1513,6 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
 Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
                           GraphDef* optimized_graph) {
-  // Supported graph patterns.
-  // clang-format off
-  FusedBatchNorm fused_batch_norm;
-  FusedBatchNormEx fused_batch_norm_ex;
-  ContractionWithBiasAdd contract_with_bias;
-  ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
-#ifndef INTEL_MKL
-  ContractionWithBatchNorm contract_with_batch_norm;
-  ContractionWithBatchNormAndActivation contract_with_batch_norm_and_activation;
-  ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
-#endif  // !INTEL_MKL
-#ifdef INTEL_MKL
-  ContractionWithBiasAddAndAdd contract_with_bias_and_add;
-  ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
-#endif  // INTEL_MKL
-  // clang-format on
   GrapplerItem mutable_item = item;
   Status status;
   RemapperContext ctx(&mutable_item, &status);
@@ -1556,6 +1540,9 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
 #ifdef INTEL_MKL
+    ContractionWithBiasAddAndAdd contract_with_bias_and_add;
+    ContractionWithBiasAndAddActivation contract_with_bias_and_add_activation;
+
     if (!item.optimization_options().is_eager_mode) {
       // Remap Conv2D+BiasAdd+Add+relu into the _FusedConv2D.
       if (FindContractionWithBiasAndAddActivation(
@@ -1578,6 +1565,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
 #endif  // !INTEL_MKL
     // Remap {Conv2D,MatMul}+BiasAdd into the _Fused{Conv2D,MatMul}
+    ContractionWithBiasAdd contract_with_bias;
     if (allow_non_differentiable_rewrites &&
         FindContractionWithBias(ctx, i, &contract_with_bias)) {
       TF_RETURN_IF_ERROR(AddFusedContractionNode(
@@ -1586,6 +1574,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
     // Remap {Conv2D,MatMul}+BiasAdd+Activation into the _Fused{Conv2D,MatMul}.
+    ContractionWithBiasAddAndActivation contract_with_bias_and_activation;
     if (allow_non_differentiable_rewrites &&
         FindContractionWithBiasAndActivation(
             ctx, i, &contract_with_bias_and_activation)) {
@@ -1603,6 +1592,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     // Remove this once TF-MKL supports _FusedConv2D with these operations.
 #ifndef INTEL_MKL
     // Remap Conv2D+Squeeze+BiasAdd into the _FusedConv2D+Squeeze.
+    ContractionWithSqueezeAndBiasAdd contract_with_squeeze_and_bias;
     if (allow_non_differentiable_rewrites &&
         FindConv2DWithSqueezeAndBias(ctx, i, &contract_with_squeeze_and_bias)) {
       TF_RETURN_IF_ERROR(
@@ -1612,6 +1602,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
     // Remap Conv2D+FusedBatchNorm into the _FusedConv2D;
+    ContractionWithBatchNorm contract_with_batch_norm;
     if (allow_non_differentiable_rewrites &&
         FindConv2DWithBatchNorm(ctx, i, &contract_with_batch_norm)) {
       TF_RETURN_IF_ERROR(AddFusedConv2DNode(&ctx, contract_with_batch_norm,
@@ -1621,6 +1612,8 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
     // Remap Conv2D+FusedBatchNorm+Activation into the _FusedConv2D;
+    ContractionWithBatchNormAndActivation
+        contract_with_batch_norm_and_activation;
     if (allow_non_differentiable_rewrites &&
         FindConv2DWithBatchNormAndActivation(
             ctx, i, &contract_with_batch_norm_and_activation)) {
@@ -1644,6 +1637,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     }
     // Remap FusedBatchNorm+<SideInput>+<Activation> into the _FusedBatchNormEx.
+    FusedBatchNormEx fused_batch_norm_ex;
     if (allow_non_differentiable_rewrites &&
         FindFusedBatchNormEx(ctx, i, &fused_batch_norm_ex)) {
       TF_RETURN_IF_ERROR(AddFusedBatchNormExNode(
@@ -1653,6 +1647,7 @@ Status Remapper::Optimize(Cluster* cluster, const GrapplerItem& item,
     // During inference, most of the inputs to FusedBatchNorm are constant, and
     // we can therefore replace the op with a much cheaper set of primitives.
+    FusedBatchNorm fused_batch_norm;
     if (FindFusedBatchNorm(ctx, i, &fused_batch_norm)) {
       TF_RETURN_IF_ERROR(AddBatchNormNodes(&ctx, fused_batch_norm));
       continue;
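
Per the comment above, the FindFusedBatchNorm path rewrites an inference-mode
FusedBatchNorm into plain arithmetic primitives. A standalone C++ sketch of the
equivalent computation (AddBatchNormNodes emits graph nodes rather than running
a loop; the function below is only an illustration):

#include <cmath>
#include <cstddef>
#include <vector>

// x is an NHWC tensor flattened to x->size() elements with C = scale.size()
// channels, so element i belongs to channel i % C. The per-channel
// coefficients a and b are computed once; each element then costs a single
// multiply-add.
void InferenceBatchNorm(std::vector<float>* x, const std::vector<float>& scale,
                        const std::vector<float>& offset,
                        const std::vector<float>& mean,
                        const std::vector<float>& variance, float epsilon) {
  const std::size_t c = scale.size();
  std::vector<float> a(c), b(c);
  for (std::size_t i = 0; i < c; ++i) {
    a[i] = scale[i] / std::sqrt(variance[i] + epsilon);
    b[i] = offset[i] - mean[i] * a[i];
  }
  for (std::size_t i = 0; i < x->size(); ++i) {
    (*x)[i] = a[i % c] * (*x)[i] + b[i % c];
  }
}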

tensorflow/core/grappler/optimizers/remapper_test.cc

@@ -112,42 +112,62 @@ TEST_F(RemapperTest, FusedBatchNormNCHW) {
     ASSERT_EQ(tensors_expected.size(), 1);
     auto tensors = EvaluateNodes(output, item.fetch);
     ASSERT_EQ(tensors.size(), 1);
-    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-5);
+    test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-3);
   }
 }
 
 TEST_F(RemapperTest, FuseBatchNormWithRelu) {
   using ::tensorflow::ops::Placeholder;
-#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
-  LOG(INFO) << "Skip FuseBatchNormWithRelu test. It requires "
-               "CUDNN_VERSION >= 7402.";
-#else
+  for (bool is_training : {true, false}) {
     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
-  auto channels_shape = ops::Placeholder::Shape({24});
+#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+    if (is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithRelu"
+                << "[is_training=" << is_training << "] "
+                << "test. It requires CUDNN_VERSION >= 7402.";
+      continue;
+    }
+#endif
-  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+#if !defined(GOOGLE_CUDA)
+    if (!is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithRelu"
+                << "[is_training=" << is_training << "]";
+      continue;
+    }
+#endif
+    const int num_channels = 24;
+
+    TensorShape channel_shape({num_channels});
+    TensorShape empty_shape({0});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
+                             ops::Placeholder::Shape({2, 8, 8, num_channels}));
     auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
-  auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
-  auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
-  auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
-  auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
+    auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
+    auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
+    auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
+    auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
 
     float epsilon = 0.1f;
     auto fbn = ops::FusedBatchNormV3(
         s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
-        ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
-            "NHWC"));
+        ops::FusedBatchNormV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat("NHWC"));
     auto relu = ops::Relu(s.WithOpName("relu"), fbn.y);
     auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
 
-  auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
-  auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto mean_t = GenerateRandomTensor<DT_FLOAT>({0});  // empty for training
-  auto var_t = GenerateRandomTensor<DT_FLOAT>({0});   // empty for training
+    auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
+    auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                             : channel_shape);
+    auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                            : channel_shape);
 
     GrapplerItem item;
     item.fetch = {"fetch"};
@@ -199,47 +219,67 @@ TEST_F(RemapperTest, FuseBatchNormWithRelu) {
     ASSERT_EQ(tensors.size(), 1);
     test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
   }
-#endif  // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+  }
 }
 
 TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
   using ::tensorflow::ops::Placeholder;
-#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
-  LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu test. It requires "
-               "CUDNN_VERSION >= 7402.";
-#else
+  for (bool is_training : {true, false}) {
     tensorflow::Scope s = tensorflow::Scope::NewRootScope();
-  auto input_shape = ops::Placeholder::Shape({2, 8, 8, 24});
-  auto channels_shape = ops::Placeholder::Shape({24});
+#if !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+    if (is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
+                << "[is_training=" << is_training << "] "
+                << "test. It requires CUDNN_VERSION >= 7402.";
+      continue;
+    }
+#endif
-  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+#if !defined(GOOGLE_CUDA)
+    if (!is_training) {
+      LOG(INFO) << "Skip FuseBatchNormWithAddAndRelu"
+                << "[is_training=" << is_training << "]";
+      continue;
+    }
+#endif
+    const int num_channels = 24;
+
+    TensorShape input_shape({2, 8, 8, num_channels});
+    TensorShape channel_shape({num_channels});
+    TensorShape empty_shape({0});
+
+    auto input = Placeholder(s.WithOpName("input"), DT_FLOAT,
+                             ops::Placeholder::Shape(input_shape));
     auto input_cast = ops::Cast(s.WithOpName("input_cast"), input, DT_HALF);
-  auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT, channels_shape);
-  auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT, channels_shape);
-  auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT, channels_shape);
-  auto var = Placeholder(s.WithOpName("var"), DT_FLOAT, channels_shape);
-  auto side_input =
-      Placeholder(s.WithOpName("side_input"), DT_FLOAT, input_shape);
+    auto scale = Placeholder(s.WithOpName("scale"), DT_FLOAT);
+    auto offset = Placeholder(s.WithOpName("offset"), DT_FLOAT);
+    auto mean = Placeholder(s.WithOpName("mean"), DT_FLOAT);
+    auto var = Placeholder(s.WithOpName("var"), DT_FLOAT);
+    auto side_input = Placeholder(s.WithOpName("side_input"), DT_FLOAT);
     auto side_input_cast =
        ops::Cast(s.WithOpName("side_input_cast"), side_input, DT_HALF);
 
     float epsilon = 0.1f;
     auto fbn = ops::FusedBatchNormV3(
         s.WithOpName("fused_batch_norm"), input_cast, scale, offset, mean, var,
-        ops::FusedBatchNormV3::IsTraining(true).Epsilon(epsilon).DataFormat(
-            "NHWC"));
+        ops::FusedBatchNormV3::IsTraining(is_training)
+            .Epsilon(epsilon)
+            .DataFormat("NHWC"));
     auto add = ops::Add(s.WithOpName("add"), fbn.y, side_input_cast);
     auto relu = ops::Relu(s.WithOpName("relu"), add);
     auto fetch = ops::Identity(s.WithOpName("fetch"), relu);
 
-  auto input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
-  auto scale_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto offset_t = GenerateRandomTensor<DT_FLOAT>({24});
-  auto mean_t = GenerateRandomTensor<DT_FLOAT>({0});  // empty for training
-  auto var_t = GenerateRandomTensor<DT_FLOAT>({0});   // empty for training
-  auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, 24});
+    auto input_t = GenerateRandomTensor<DT_FLOAT>(input_shape);
+    auto scale_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto offset_t = GenerateRandomTensor<DT_FLOAT>(channel_shape);
+    auto mean_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                             : channel_shape);
+    auto var_t = GenerateRandomTensor<DT_FLOAT>(is_training ? empty_shape
+                                                            : channel_shape);
+    auto side_input_t = GenerateRandomTensor<DT_FLOAT>({2, 8, 8, num_channels});
 
     GrapplerItem item;
     item.fetch = {"fetch"};
@@ -290,7 +330,7 @@ TEST_F(RemapperTest, FuseBatchNormWithAddAndRelu) {
     ASSERT_EQ(tensors.size(), 1);
     test::ExpectClose(tensors[0], tensors_expected[0], 1e-2, /*rtol=*/1e-2);
   }
-#endif  // !defined(GOOGLE_CUDA) || !(CUDNN_VERSION >= 7402)
+  }
 }
 
TEST_F(RemapperTest, FuseConv2DWithBias) {