From 428dbbfd8a12d4e3239f7d34934e2074e94aa854 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 2 Aug 2017 09:12:04 -0700
Subject: [PATCH] [BatchNormRewriter] Add option to use multi-output fusion.

Having an option to opt out of multi-output fusion has two benefits:
- It enables the CPU backend, which doesn't have fusion support, to handle
  the rewritten nodes.
- It lets our benchmarks turn off multi-output fusion.

RELNOTES: n/a
PiperOrigin-RevId: 163990633
---
 .../xla/service/batchnorm_rewriter.cc         | 53 +++++++++++--------
 .../compiler/xla/service/batchnorm_rewriter.h |  7 ++-
 tensorflow/compiler/xla/service/cpu/BUILD     |  1 +
 .../compiler/xla/service/cpu/cpu_compiler.cc  |  5 ++
 tensorflow/compiler/xla/service/gpu/BUILD     |  1 +
 .../compiler/xla/service/gpu/gpu_compiler.cc  |  7 +++
 .../xla/tests/batch_normalization_test.cc     | 45 +++++++---------
 7 files changed, 67 insertions(+), 52 deletions(-)

diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
index a68db6ae8dd..2a245c9fadd 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
@@ -60,7 +60,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
 
   // Runs the visitor on a computation.
   static bool Run(HloComputation* computation, bool rewrite_training_op,
-                  bool rewrite_grad_op);
+                  bool rewrite_grad_op, bool use_fusion);
 
   // Returns whether any batch norm ops were rewritten.
   const bool changed() const { return changed_; }
@@ -70,10 +70,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
  private:
   explicit BatchNormRewriterVisitor(HloComputation* computation,
                                     bool rewrite_training_op,
-                                    bool rewrite_grad_op)
+                                    bool rewrite_grad_op, bool use_fusion)
       : computation_(computation),
        rewrite_training_op_(rewrite_training_op),
-        rewrite_grad_op_(rewrite_grad_op) {}
+        rewrite_grad_op_(rewrite_grad_op),
+        use_fusion_(use_fusion) {}
 
   HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type,
                                              HloOpcode opcode) {
@@ -94,6 +95,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
 
   bool rewrite_training_op_;
   bool rewrite_grad_op_;
+  bool use_fusion_;
 
   // Whether rewrite has occurred.
   bool changed_ = false;
@@ -124,10 +126,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
 
 bool BatchNormRewriterVisitor::Run(HloComputation* computation,
                                    bool rewrite_training_op,
-                                   bool rewrite_grad_op) {
+                                   bool rewrite_grad_op, bool use_fusion) {
   BatchNormRewriterVisitor visitor(computation,
                                    /*rewrite_training_op=*/rewrite_training_op,
-                                   /*rewrite_grad_op=*/rewrite_grad_op);
+                                   /*rewrite_grad_op=*/rewrite_grad_op,
+                                   /*use_fusion=*/use_fusion);
   TF_CHECK_OK(computation->Accept(&visitor));
   return visitor.changed_;
 }
@@ -189,18 +192,20 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
           add_reduce_computation));
 
   // Fuse two parallel reduces together to improve performance.
-  auto tuple = computation_->AddInstruction(
-      HloInstruction::CreateTuple({sum, squared_sum}));
+  if (use_fusion_) {
+    auto tuple = computation_->AddInstruction(
+        HloInstruction::CreateTuple({sum, squared_sum}));
 
-  auto fused = computation_->CreateFusionInstruction(
-      {tuple, sum, squared_sum, operand_squared},
-      HloInstruction::FusionKind::kInput);
+    auto fused = computation_->CreateFusionInstruction(
+        {tuple, sum, squared_sum, operand_squared},
+        HloInstruction::FusionKind::kInput);
 
-  sum = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
+    sum = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
 
-  squared_sum = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+    squared_sum = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+  }
 
   // E[X].
   auto mean = computation_->AddInstruction(HloInstruction::CreateBinary(
@@ -352,17 +357,19 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
           feature_shape, grad_output, zero, dimensions_without_feature,
           add_reduce_computation));
 
-  auto tuple = computation_->AddInstruction(
-      HloInstruction::CreateTuple({grad_scale, grad_beta}));
+  if (use_fusion_) {
+    auto tuple = computation_->AddInstruction(
+        HloInstruction::CreateTuple({grad_scale, grad_beta}));
 
-  auto fused = computation_->CreateFusionInstruction(
-      {tuple, grad_scale, grad_beta}, HloInstruction::FusionKind::kInput);
+    auto fused = computation_->CreateFusionInstruction(
+        {tuple, grad_scale, grad_beta}, HloInstruction::FusionKind::kInput);
 
-  grad_scale = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
+    grad_scale = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
 
-  grad_beta = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+    grad_beta = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+  }
 
   TF_CHECK_OK(ReplaceWithNewInstruction(
       batch_norm,
@@ -385,7 +392,7 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
   }
   for (auto& comp : computations) {
     if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_,
-                                      rewrite_grad_op_)) {
+                                      rewrite_grad_op_, use_fusion_)) {
       changed = true;
     }
   }
diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.h b/tensorflow/compiler/xla/service/batchnorm_rewriter.h
index 6d176f4849a..d3ffb31032e 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter.h
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.h
@@ -28,10 +28,12 @@ namespace xla {
 // logic.
 class BatchNormRewriter : public HloPassInterface {
  public:
+  // When use_fusion is set, a multi-output fusion node is created.
   BatchNormRewriter(bool rewrite_training_op = false,
-                    bool rewrite_grad_op = false)
+                    bool rewrite_grad_op = false, bool use_fusion = true)
       : rewrite_training_op_(rewrite_training_op),
-        rewrite_grad_op_(rewrite_grad_op) {}
+        rewrite_grad_op_(rewrite_grad_op),
+        use_fusion_(use_fusion) {}
   ~BatchNormRewriter() = default;
   tensorflow::StringPiece name() const override { return "batchnorm_rewriter"; }
 
@@ -42,6 +44,7 @@ class BatchNormRewriter : public HloPassInterface {
  private:
   bool rewrite_training_op_;
   bool rewrite_grad_op_;
+  bool use_fusion_;
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 06647a7bbc4..0adaedd36fd 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -53,6 +53,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:algebraic_simplifier",
+        "//tensorflow/compiler/xla/service:batchnorm_rewriter",
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:buffer_liveness",
         "//tensorflow/compiler/xla/service:compiler",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 8ee41e8ccf0..710cae45a70 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -42,6 +42,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/copy_insertion.h"
@@ -266,6 +267,10 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
   {
     auto& pass =
         pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
+    pass.AddPass<BatchNormRewriter>(
+        /*rewrite_training_op=*/true,
+        /*rewrite_grad_op=*/true,
+        /*use_fusion=*/false);
     pass.AddPass<AlgebraicSimplifier>(
         /*is_layout_sensitive=*/false,
         [](const Shape&, const Shape&) { return false; },
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index cdd7c8187c9..b7060560b40 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -417,6 +417,7 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:algebraic_simplifier",
+        "//tensorflow/compiler/xla/service:batchnorm_rewriter",
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:buffer_liveness",
         "//tensorflow/compiler/xla/service:compiler",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index 8abfcb82017..cf83a43a749 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -130,6 +131,12 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module, { auto& pass = pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification"); + // TODO(b/62764704): Do not rewrite on GPU, use cuDNN's BatchNorm APIs + // instead. + pass.AddPass<BatchNormRewriter>( + /*rewrite_training_op=*/true, + /*rewrite_grad_op=*/true, + /*use_fusion=*/false); pass.AddPass<AlgebraicSimplifier>( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 4fe45a3d9c5..ada647ca05f 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -210,8 +210,7 @@ class BatchNormTest : public ClientLibraryTestBase, public ::testing::WithParamInterface<BatchNormTestParam> { }; -// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. -XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) { +XLA_TEST_P(BatchNormTest, RandomizedTests) { float epsilon = 0.001; ComputationBuilder builder(client_, TestName()); const std::vector<int64>& bounds = GetParam().bounds; @@ -264,15 +263,15 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) { var[i] = square_mean[i] - mean_square[i]; } - Array4D<float> mean_4D = + Array4D<float> mean4D = *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index); - auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index); - auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); - auto offset_4D = + auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index); + auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index); + auto offset4D = *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index); - auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean_4D, var_4D, - scale_4D, offset_4D, epsilon); + auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D, + scale4D, offset4D, epsilon); auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized); @@ -307,9 +306,7 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) { ErrorSpec(0.01, 1)); } -// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20. 
-XLA_TEST_P(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
-    DISABLED_ON_GPU(RandomizedGradTests)))) {
+XLA_TEST_P(BatchNormTest, RandomizedGradTests) {
   float epsilon = 0.001;
   ComputationBuilder builder(client_, TestName());
   const std::vector<int64>& bounds = GetParam().bounds;
@@ -365,23 +362,23 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
     var[i] = square_mean[i] - mean_square[i];
   }
 
-  Array4D<float> mean_4D =
+  Array4D<float> mean4D =
       *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
-  auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
-  auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
 
   auto var_add_epsilon = *ReferenceUtil::MapArray4D(
-      var_4D, [epsilon](float a) { return std::sqrt(a + epsilon); });
+      var4D, [epsilon](float a) { return std::sqrt(a + epsilon); });
 
   auto grad_output_times_var =
       *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
                                  [](float a, float b) { return a * b; });
 
   auto grad_activation = *ReferenceUtil::MapArray4D(
-      grad_output_times_var, scale_4D, [](float a, float b) { return a * b; });
+      grad_output_times_var, scale4D, [](float a, float b) { return a * b; });
 
   auto activation_shifted = *ReferenceUtil::MapArray4D(
-      input_array, mean_4D, [](float a, float b) { return a - b; });
+      input_array, mean4D, [](float a, float b) { return a - b; });
 
   auto grad_scale_before_reduction =
       *ReferenceUtil::MapArray4D(grad_output_times_var, activation_shifted,
@@ -460,8 +457,7 @@ INSTANTIATE_TEST_CASE_P(
         // is correct after relayout.
         BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100}));
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTraining)) {
+XLA_TEST_F(BatchNormTest, BasicTraining) {
   const int kFeatureIndex = 3;
   ComputationBuilder builder(client_, TestName());
 
@@ -485,8 +481,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTraining)) {
   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTrainingOnSublane)) {
+XLA_TEST_F(BatchNormTest, BasicTrainingOnSublane) {
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());
 
@@ -510,7 +505,6 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTrainingOnSublane)) {
   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
 XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) {
   // Use 0 dimension as feature, tests layout analyzer.
   const int kFeatureIndex = 0;
@@ -543,8 +537,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) {
                           ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(LargeEpsilonTest)) {
+XLA_TEST_F(BatchNormTest, LargeEpsilonTest) {
   // Test the correctness of choosing a large epsilon value.
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());
@@ -577,9 +570,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(LargeEpsilonTest)) {
                            ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on CPU and GPU. Disabled on 2017-07-11.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
-    DISABLED_ON_GPU(BatchNormGradBasic)))) {
+XLA_TEST_F(BatchNormTest, BatchNormGradBasic) {
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());
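
For reference, a minimal sketch of how a backend without multi-output fusion
support could invoke the rewriter with the new option. The RunBatchNormRewrite
helper below is made up for illustration and is not part of this patch; the
constructor arguments mirror the cpu_compiler.cc hunk above, and
BatchNormRewriter::Run is the entry point declared in batchnorm_rewriter.h.

// Illustrative sketch only -- not part of the patch.
#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Hypothetical helper showing the intended use of the use_fusion knob.
StatusOr<bool> RunBatchNormRewrite(HloModule* module) {
  // Expand BatchNormTraining/BatchNormGrad into primitive HLO ops, but do not
  // wrap the parallel reduces in a multi-output fusion node, so a backend
  // without fusion support (e.g. the CPU backend) can consume the result.
  BatchNormRewriter rewriter(/*rewrite_training_op=*/true,
                             /*rewrite_grad_op=*/true,
                             /*use_fusion=*/false);
  // Returns whether any batch norm ops were rewritten.
  return rewriter.Run(module);
}

}  // namespace xla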