From 428dbbfd8a12d4e3239f7d34934e2074e94aa854 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 2 Aug 2017 09:12:04 -0700
Subject: [PATCH] [BatchNormRewriter] Add option to use multi-output fusion.

Having an option to opt out of multi-output fusion has two benefits (see the usage sketch below):

- It enables the CPU backend, which doesn't have fusion support, to handle the rewritten nodes.

- It lets our benchmarks turn multi-output fusion off.
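A backend or benchmark that wants the rewrite without multi-output fusion can pass the new flag when adding the pass to its pipeline. A minimal sketch, mirroring the cpu_compiler.cc hunk below (the surrounding "simplification" pipeline setup is assumed):

  pass.AddPass<BatchNormRewriter>(
      /*rewrite_training_op=*/true,
      /*rewrite_grad_op=*/true,
      /*use_fusion=*/false);

use_fusion defaults to true, so existing callers keep the fused parallel reduces.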

RELNOTES: n/a
PiperOrigin-RevId: 163990633
---
 .../xla/service/batchnorm_rewriter.cc         | 53 +++++++++++--------
 .../compiler/xla/service/batchnorm_rewriter.h |  7 ++-
 tensorflow/compiler/xla/service/cpu/BUILD     |  1 +
 .../compiler/xla/service/cpu/cpu_compiler.cc  |  5 ++
 tensorflow/compiler/xla/service/gpu/BUILD     |  1 +
 .../compiler/xla/service/gpu/gpu_compiler.cc  |  7 +++
 .../xla/tests/batch_normalization_test.cc     | 45 +++++++---------
 7 files changed, 67 insertions(+), 52 deletions(-)

diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
index a68db6ae8dd..2a245c9fadd 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
@@ -60,7 +60,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
 
   // Runs the visitor on a computation.
   static bool Run(HloComputation* computation, bool rewrite_training_op,
-                  bool rewrite_grad_op);
+                  bool rewrite_grad_op, bool use_fusion);
 
   // Returns whether any batch norm ops were rewritten.
   const bool changed() const { return changed_; }
@@ -70,10 +70,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
  private:
   explicit BatchNormRewriterVisitor(HloComputation* computation,
                                     bool rewrite_training_op,
-                                    bool rewrite_grad_op)
+                                    bool rewrite_grad_op, bool use_fusion)
       : computation_(computation),
         rewrite_training_op_(rewrite_training_op),
-        rewrite_grad_op_(rewrite_grad_op) {}
+        rewrite_grad_op_(rewrite_grad_op),
+        use_fusion_(use_fusion) {}
 
   HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type,
                                              HloOpcode opcode) {
@@ -94,6 +95,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
 
   bool rewrite_training_op_;
   bool rewrite_grad_op_;
+  bool use_fusion_;
 
   // Whether rewrite has occurred.
   bool changed_ = false;
@@ -124,10 +126,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
 
 bool BatchNormRewriterVisitor::Run(HloComputation* computation,
                                    bool rewrite_training_op,
-                                   bool rewrite_grad_op) {
+                                   bool rewrite_grad_op, bool use_fusion) {
   BatchNormRewriterVisitor visitor(computation,
                                    /*rewrite_training_op=*/rewrite_training_op,
-                                   /*rewrite_grad_op=*/rewrite_grad_op);
+                                   /*rewrite_grad_op=*/rewrite_grad_op,
+                                   /*use_fusion=*/use_fusion);
   TF_CHECK_OK(computation->Accept(&visitor));
   return visitor.changed_;
 }
@@ -189,18 +192,20 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
       add_reduce_computation));
 
   // Fuse two parallel reduces together to improve performance.
-  auto tuple = computation_->AddInstruction(
-      HloInstruction::CreateTuple({sum, squared_sum}));
+  if (use_fusion_) {
+    auto tuple = computation_->AddInstruction(
+        HloInstruction::CreateTuple({sum, squared_sum}));
 
-  auto fused = computation_->CreateFusionInstruction(
-      {tuple, sum, squared_sum, operand_squared},
-      HloInstruction::FusionKind::kInput);
+    auto fused = computation_->CreateFusionInstruction(
+        {tuple, sum, squared_sum, operand_squared},
+        HloInstruction::FusionKind::kInput);
 
-  sum = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
+    sum = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
 
-  squared_sum = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+    squared_sum = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+  }
 
   // E[X].
   auto mean = computation_->AddInstruction(HloInstruction::CreateBinary(
@@ -352,17 +357,19 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
       feature_shape, grad_output, zero, dimensions_without_feature,
       add_reduce_computation));
 
-  auto tuple = computation_->AddInstruction(
-      HloInstruction::CreateTuple({grad_scale, grad_beta}));
+  if (use_fusion_) {
+    auto tuple = computation_->AddInstruction(
+        HloInstruction::CreateTuple({grad_scale, grad_beta}));
 
-  auto fused = computation_->CreateFusionInstruction(
-      {tuple, grad_scale, grad_beta}, HloInstruction::FusionKind::kInput);
+    auto fused = computation_->CreateFusionInstruction(
+        {tuple, grad_scale, grad_beta}, HloInstruction::FusionKind::kInput);
 
-  grad_scale = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
+    grad_scale = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
 
-  grad_beta = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+    grad_beta = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+  }
 
   TF_CHECK_OK(ReplaceWithNewInstruction(
       batch_norm,
@@ -385,7 +392,7 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
   }
   for (auto& comp : computations) {
     if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_,
-                                      rewrite_grad_op_)) {
+                                      rewrite_grad_op_, use_fusion_)) {
       changed = true;
     }
   }
diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.h b/tensorflow/compiler/xla/service/batchnorm_rewriter.h
index 6d176f4849a..d3ffb31032e 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.h
@@ -28,10 +28,12 @@ namespace xla {
 // logic.
 class BatchNormRewriter : public HloPassInterface {
  public:
+  // When use_fusion is set, a multi-output fusion node is created.
   BatchNormRewriter(bool rewrite_training_op = false,
-                    bool rewrite_grad_op = false)
+                    bool rewrite_grad_op = false, bool use_fusion = true)
       : rewrite_training_op_(rewrite_training_op),
-        rewrite_grad_op_(rewrite_grad_op) {}
+        rewrite_grad_op_(rewrite_grad_op),
+        use_fusion_(use_fusion) {}
   ~BatchNormRewriter() = default;
   tensorflow::StringPiece name() const override { return "batchnorm_rewriter"; }
 
@@ -42,6 +44,7 @@ class BatchNormRewriter : public HloPassInterface {
  private:
   bool rewrite_training_op_;
   bool rewrite_grad_op_;
+  bool use_fusion_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 06647a7bbc4..0adaedd36fd 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -53,6 +53,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:algebraic_simplifier",
+        "//tensorflow/compiler/xla/service:batchnorm_rewriter",
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:buffer_liveness",
         "//tensorflow/compiler/xla/service:compiler",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 8ee41e8ccf0..710cae45a70 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -42,6 +42,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/copy_insertion.h"
@@ -266,6 +267,10 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
   {
     auto& pass =
         pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
+    pass.AddPass<BatchNormRewriter>(
+        /*rewrite_training_op=*/true,
+        /*rewrite_grad_op=*/true,
+        /*use_fusion=*/false);
     pass.AddPass<AlgebraicSimplifier>(
         /*is_layout_sensitive=*/false,
         [](const Shape&, const Shape&) { return false; },
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index cdd7c8187c9..b7060560b40 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -417,6 +417,7 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:algebraic_simplifier",
+        "//tensorflow/compiler/xla/service:batchnorm_rewriter",
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:buffer_liveness",
         "//tensorflow/compiler/xla/service:compiler",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 8abfcb82017..cf83a43a749 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
@@ -130,6 +131,12 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
     {
       auto& pass =
           pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
+      // TODO(b/62764704): Do not rewrite on GPU, use cuDNN's BatchNorm APIs
+      // instead.
+      pass.AddPass<BatchNormRewriter>(
+          /*rewrite_training_op=*/true,
+          /*rewrite_grad_op=*/true,
+          /*use_fusion=*/false);
       pass.AddPass<AlgebraicSimplifier>(
           /*is_layout_sensitive=*/false,
           [](const Shape&, const Shape&) { return false; });
diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
index 4fe45a3d9c5..ada647ca05f 100644
--- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc
+++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc
@@ -210,8 +210,7 @@ class BatchNormTest : public ClientLibraryTestBase,
                       public ::testing::WithParamInterface<BatchNormTestParam> {
 };
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) {
+XLA_TEST_P(BatchNormTest, RandomizedTests) {
   float epsilon = 0.001;
   ComputationBuilder builder(client_, TestName());
   const std::vector<int64>& bounds = GetParam().bounds;
@@ -264,15 +263,15 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) {
     var[i] = square_mean[i] - mean_square[i];
   }
 
-  Array4D<float> mean_4D =
+  Array4D<float> mean4D =
       *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
-  auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
-  auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
-  auto offset_4D =
+  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+  auto offset4D =
       *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
 
-  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean_4D, var_4D,
-                                                scale_4D, offset_4D, epsilon);
+  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
+                                                scale4D, offset4D, epsilon);
 
   auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);
 
@@ -307,9 +306,7 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) {
       ErrorSpec(0.01, 1));
 }
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_P(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
-                              DISABLED_ON_GPU(RandomizedGradTests)))) {
+XLA_TEST_P(BatchNormTest, RandomizedGradTests) {
   float epsilon = 0.001;
   ComputationBuilder builder(client_, TestName());
   const std::vector<int64>& bounds = GetParam().bounds;
@@ -365,23 +362,23 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
     var[i] = square_mean[i] - mean_square[i];
   }
 
-  Array4D<float> mean_4D =
+  Array4D<float> mean4D =
       *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
-  auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
-  auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
 
   auto var_add_epsilon = *ReferenceUtil::MapArray4D(
-      var_4D, [epsilon](float a) { return std::sqrt(a + epsilon); });
+      var4D, [epsilon](float a) { return std::sqrt(a + epsilon); });
 
   auto grad_output_times_var =
       *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
                                  [](float a, float b) { return a * b; });
 
   auto grad_activation = *ReferenceUtil::MapArray4D(
-      grad_output_times_var, scale_4D, [](float a, float b) { return a * b; });
+      grad_output_times_var, scale4D, [](float a, float b) { return a * b; });
 
   auto activation_shifted = *ReferenceUtil::MapArray4D(
-      input_array, mean_4D, [](float a, float b) { return a - b; });
+      input_array, mean4D, [](float a, float b) { return a - b; });
 
   auto grad_scale_before_reduction =
       *ReferenceUtil::MapArray4D(grad_output_times_var, activation_shifted,
@@ -460,8 +457,7 @@ INSTANTIATE_TEST_CASE_P(
                       // is correct after relayout.
                       BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100}));
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTraining)) {
+XLA_TEST_F(BatchNormTest, BasicTraining) {
   const int kFeatureIndex = 3;
   ComputationBuilder builder(client_, TestName());
 
@@ -485,8 +481,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTraining)) {
   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTrainingOnSublane)) {
+XLA_TEST_F(BatchNormTest, BasicTrainingOnSublane) {
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());
 
@@ -510,7 +505,6 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTrainingOnSublane)) {
   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
 XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) {
   // Use 0 dimension as feature, tests layout analyzer.
   const int kFeatureIndex = 0;
@@ -543,8 +537,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) {
                          ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(LargeEpsilonTest)) {
+XLA_TEST_F(BatchNormTest, LargeEpsilonTest) {
   // Test the correctness of choosing a large epsilon value.
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());
@@ -577,9 +570,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(LargeEpsilonTest)) {
                          ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on CPU and GPU. Disabled on 2017-07-11.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
-                              DISABLED_ON_GPU(BatchNormGradBasic)))) {
+XLA_TEST_F(BatchNormTest, BatchNormGradBasic) {
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());