[BatchNormRewriter] Add option to use multi-output fusion.

Having an option to opt out of multi-output fusion has two benefits:

- This enables the CPU backend, which doesn't have fusion support, to handle the rewritten nodes.

- This lets our benchmarks turn multi-output fusion off (a usage sketch follows below).
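
For illustration, here is how a backend can wire the flag into its pass
pipeline (a minimal sketch mirroring the CpuCompiler change below; the
helper and its backend_has_fusion parameter are hypothetical):

    #include "tensorflow/compiler/xla/service/batchnorm_rewriter.h"
    #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

    namespace xla {

    // Hypothetical helper: backends without fusion support (e.g. CPU) pass
    // use_fusion=false so the rewriter emits plain, unfused reduces.
    void AddBatchNormRewrite(HloPassPipeline* pipeline,
                             bool backend_has_fusion) {
      pipeline->AddPass<BatchNormRewriter>(
          /*rewrite_training_op=*/true,
          /*rewrite_grad_op=*/true,
          /*use_fusion=*/backend_has_fusion);
    }

    }  // namespace xla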

RELNOTES: n/a
PiperOrigin-RevId: 163990633
A. Unique TensorFlower 2017-08-02 09:12:04 -07:00 committed by Benoit Steiner
parent f3ecb0171f
commit 428dbbfd8a
7 changed files with 67 additions and 52 deletions

tensorflow/compiler/xla/service/batchnorm_rewriter.cc

@@ -60,7 +60,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
   // Runs the visitor on a computation.
   static bool Run(HloComputation* computation, bool rewrite_training_op,
-                  bool rewrite_grad_op);
+                  bool rewrite_grad_op, bool use_fusion);

   // Returns whether any batch norm ops were rewritten.
   const bool changed() const { return changed_; }
@@ -70,10 +70,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
  private:
   explicit BatchNormRewriterVisitor(HloComputation* computation,
                                     bool rewrite_training_op,
-                                    bool rewrite_grad_op)
+                                    bool rewrite_grad_op, bool use_fusion)
       : computation_(computation),
         rewrite_training_op_(rewrite_training_op),
-        rewrite_grad_op_(rewrite_grad_op) {}
+        rewrite_grad_op_(rewrite_grad_op),
+        use_fusion_(use_fusion) {}

   HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type,
                                              HloOpcode opcode) {
@@ -94,6 +95,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
   bool rewrite_training_op_;
   bool rewrite_grad_op_;
+  bool use_fusion_;

   // Whether rewrite has occurred.
   bool changed_ = false;
@@ -124,10 +126,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
 bool BatchNormRewriterVisitor::Run(HloComputation* computation,
                                    bool rewrite_training_op,
-                                   bool rewrite_grad_op) {
+                                   bool rewrite_grad_op, bool use_fusion) {
   BatchNormRewriterVisitor visitor(computation,
                                    /*rewrite_training_op=*/rewrite_training_op,
-                                   /*rewrite_grad_op=*/rewrite_grad_op);
+                                   /*rewrite_grad_op=*/rewrite_grad_op,
+                                   /*use_fusion=*/use_fusion);
   TF_CHECK_OK(computation->Accept(&visitor));
   return visitor.changed_;
 }
@@ -189,18 +192,20 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
       add_reduce_computation));

   // Fuse two parallel reduces together to improve performance.
-  auto tuple = computation_->AddInstruction(
-      HloInstruction::CreateTuple({sum, squared_sum}));
+  if (use_fusion_) {
+    auto tuple = computation_->AddInstruction(
+        HloInstruction::CreateTuple({sum, squared_sum}));

-  auto fused = computation_->CreateFusionInstruction(
-      {tuple, sum, squared_sum, operand_squared},
-      HloInstruction::FusionKind::kInput);
+    auto fused = computation_->CreateFusionInstruction(
+        {tuple, sum, squared_sum, operand_squared},
+        HloInstruction::FusionKind::kInput);

-  sum = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
+    sum = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));

-  squared_sum = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+    squared_sum = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+  }

   // E[X].
   auto mean = computation_->AddInstruction(HloInstruction::CreateBinary(
@@ -352,17 +357,19 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
       feature_shape, grad_output, zero, dimensions_without_feature,
       add_reduce_computation));

-  auto tuple = computation_->AddInstruction(
-      HloInstruction::CreateTuple({grad_scale, grad_beta}));
+  if (use_fusion_) {
+    auto tuple = computation_->AddInstruction(
+        HloInstruction::CreateTuple({grad_scale, grad_beta}));

-  auto fused = computation_->CreateFusionInstruction(
-      {tuple, grad_scale, grad_beta}, HloInstruction::FusionKind::kInput);
+    auto fused = computation_->CreateFusionInstruction(
+        {tuple, grad_scale, grad_beta}, HloInstruction::FusionKind::kInput);

-  grad_scale = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
+    grad_scale = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));

-  grad_beta = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+    grad_beta = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+  }

   TF_CHECK_OK(ReplaceWithNewInstruction(
       batch_norm,
@@ -385,7 +392,7 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
   }
   for (auto& comp : computations) {
     if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_,
-                                      rewrite_grad_op_)) {
+                                      rewrite_grad_op_, use_fusion_)) {
       changed = true;
     }
   }
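
Note: the guarded blocks above exist so that fusion-capable backends compute
E[X] and E[X^2] in a single pass over the operand; the rewriter then derives
the variance via Var[X] = E[X^2] - E[X]^2. A scalar sketch of that arithmetic
(hypothetical helper, not code from this change):

    #include <vector>

    // One loop produces both sums, which is what fusing the two reduces buys;
    // mean and variance then follow from the E[X^2] - E[X]^2 identity.
    void BatchStats(const std::vector<float>& x, float* mean, float* var) {
      float sum = 0.0f, squared_sum = 0.0f;
      for (float v : x) {
        sum += v;
        squared_sum += v * v;
      }
      const float n = static_cast<float>(x.size());
      *mean = sum / n;
      *var = squared_sum / n - (*mean) * (*mean);
    }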

tensorflow/compiler/xla/service/batchnorm_rewriter.h

@@ -28,10 +28,12 @@ namespace xla {
 // logic.
 class BatchNormRewriter : public HloPassInterface {
  public:
+  // When use_fusion is set, a multi-output fusion node is created.
   BatchNormRewriter(bool rewrite_training_op = false,
-                    bool rewrite_grad_op = false)
+                    bool rewrite_grad_op = false, bool use_fusion = true)
       : rewrite_training_op_(rewrite_training_op),
-        rewrite_grad_op_(rewrite_grad_op) {}
+        rewrite_grad_op_(rewrite_grad_op),
+        use_fusion_(use_fusion) {}
   ~BatchNormRewriter() = default;
   tensorflow::StringPiece name() const override { return "batchnorm_rewriter"; }
@@ -42,6 +44,7 @@ class BatchNormRewriter : public HloPassInterface {
  private:
   bool rewrite_training_op_;
   bool rewrite_grad_op_;
+  bool use_fusion_;
 };

 }  // namespace xla
} // namespace xla

tensorflow/compiler/xla/service/cpu/BUILD

@@ -53,6 +53,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:algebraic_simplifier",
+        "//tensorflow/compiler/xla/service:batchnorm_rewriter",
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:buffer_liveness",
         "//tensorflow/compiler/xla/service:compiler",

tensorflow/compiler/xla/service/cpu/cpu_compiler.cc

@@ -42,6 +42,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/copy_insertion.h"
@@ -266,6 +267,10 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
   {
     auto& pass =
         pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
+    pass.AddPass<BatchNormRewriter>(
+        /*rewrite_training_op=*/true,
+        /*rewrite_grad_op=*/true,
+        /*use_fusion=*/false);
     pass.AddPass<AlgebraicSimplifier>(
         /*is_layout_sensitive=*/false,
         [](const Shape&, const Shape&) { return false; },

tensorflow/compiler/xla/service/gpu/BUILD

@@ -417,6 +417,7 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:algebraic_simplifier",
+        "//tensorflow/compiler/xla/service:batchnorm_rewriter",
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:buffer_liveness",
         "//tensorflow/compiler/xla/service:compiler",

tensorflow/compiler/xla/service/gpu/gpu_compiler.cc

@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
@@ -130,6 +131,12 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
   {
     auto& pass =
         pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
+    // TODO(b/62764704): Do not rewrite on GPU, use cuDNN's BatchNorm APIs
+    // instead.
+    pass.AddPass<BatchNormRewriter>(
+        /*rewrite_training_op=*/true,
+        /*rewrite_grad_op=*/true,
+        /*use_fusion=*/false);
     pass.AddPass<AlgebraicSimplifier>(
         /*is_layout_sensitive=*/false,
         [](const Shape&, const Shape&) { return false; });

tensorflow/compiler/xla/tests/batch_norm_test.cc

@@ -210,8 +210,7 @@ class BatchNormTest : public ClientLibraryTestBase,
                       public ::testing::WithParamInterface<BatchNormTestParam> {
 };

-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) {
+XLA_TEST_P(BatchNormTest, RandomizedTests) {
   float epsilon = 0.001;
   ComputationBuilder builder(client_, TestName());
   const std::vector<int64>& bounds = GetParam().bounds;
@@ -264,15 +263,15 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) {
     var[i] = square_mean[i] - mean_square[i];
   }

-  Array4D<float> mean_4D =
+  Array4D<float> mean4D =
       *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
-  auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
-  auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
-  auto offset_4D =
+  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+  auto offset4D =
       *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

-  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean_4D, var_4D,
-                                                scale_4D, offset_4D, epsilon);
+  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
+                                                scale4D, offset4D, epsilon);

   auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);
@@ -307,9 +306,7 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) {
                            ErrorSpec(0.01, 1));
 }

-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_P(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
-               DISABLED_ON_GPU(RandomizedGradTests)))) {
+XLA_TEST_P(BatchNormTest, RandomizedGradTests) {
   float epsilon = 0.001;
   ComputationBuilder builder(client_, TestName());
   const std::vector<int64>& bounds = GetParam().bounds;
@@ -365,23 +362,23 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
     var[i] = square_mean[i] - mean_square[i];
   }

-  Array4D<float> mean_4D =
+  Array4D<float> mean4D =
       *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
-  auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
-  auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);

   auto var_add_epsilon = *ReferenceUtil::MapArray4D(
-      var_4D, [epsilon](float a) { return std::sqrt(a + epsilon); });
+      var4D, [epsilon](float a) { return std::sqrt(a + epsilon); });

   auto grad_output_times_var =
       *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
                                  [](float a, float b) { return a * b; });

   auto grad_activation = *ReferenceUtil::MapArray4D(
-      grad_output_times_var, scale_4D, [](float a, float b) { return a * b; });
+      grad_output_times_var, scale4D, [](float a, float b) { return a * b; });

   auto activation_shifted = *ReferenceUtil::MapArray4D(
-      input_array, mean_4D, [](float a, float b) { return a - b; });
+      input_array, mean4D, [](float a, float b) { return a - b; });

   auto grad_scale_before_reduction =
       *ReferenceUtil::MapArray4D(grad_output_times_var, activation_shifted,
@@ -460,8 +457,7 @@ INSTANTIATE_TEST_CASE_P(
                         // is correct after relayout.
                         BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100}));

-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTraining)) {
+XLA_TEST_F(BatchNormTest, BasicTraining) {
   const int kFeatureIndex = 3;
   ComputationBuilder builder(client_, TestName());
@@ -485,8 +481,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTraining)) {
   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }

-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTrainingOnSublane)) {
+XLA_TEST_F(BatchNormTest, BasicTrainingOnSublane) {
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());
@@ -510,7 +505,6 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTrainingOnSublane)) {
   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }

-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
 XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) {
   // Use 0 dimension as feature, tests layout analyzer.
   const int kFeatureIndex = 0;
@@ -543,8 +537,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) {
                          ErrorSpec(0.1));
 }

-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(LargeEpsilonTest)) {
+XLA_TEST_F(BatchNormTest, LargeEpsilonTest) {
   // Test the correctness of choosing a large epsilon value.
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());
@@ -577,9 +570,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(LargeEpsilonTest)) {
                          ErrorSpec(0.1));
 }

-// TODO(b/62764704): Implement on CPU and GPU. Disabled on 2017-07-11.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
-               DISABLED_ON_GPU(BatchNormGradBasic)))) {
+XLA_TEST_F(BatchNormTest, BatchNormGradBasic) {
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());
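
For reference, the elementwise identity the randomized tests check is the
standard batch-norm formula; a scalar sketch (assuming the same epsilon
placement as the ReferenceUtil helpers used above):

    #include <cmath>

    // y = (x - mean) / sqrt(var + epsilon) * scale + offset, per feature.
    float BatchNormRef(float x, float mean, float var, float scale,
                       float offset, float epsilon) {
      return (x - mean) / std::sqrt(var + epsilon) * scale + offset;
    }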