[BatchNormRewriter] Add option to use multi-output fusion.
Having an option to opt out of multi-output fusion has two benefits:
- It enables the CPU backend, which does not have fusion support, to handle the rewritten nodes.
- It lets our benchmarks turn multi-output fusion off.

RELNOTES: n/a
PiperOrigin-RevId: 163990633
parent f3ecb0171f
commit 428dbbfd8a
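For context, a minimal sketch of how a backend pipeline can now opt out of the fusion rewrite. The pipeline name and surrounding setup are illustrative only; the AddPass arguments mirror the CPU and GPU pipeline changes in the diff below:

    // Hypothetical pipeline setup; only the BatchNormRewriter arguments are
    // taken from this change.
    HloPassPipeline pipeline("simplification");
    pipeline.AddPass<BatchNormRewriter>(
        /*rewrite_training_op=*/true,
        /*rewrite_grad_op=*/true,
        /*use_fusion=*/false);  // Expand batch norm without multi-output fusion.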
@@ -60,7 +60,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
   // Runs the visitor on a computation.
   static bool Run(HloComputation* computation, bool rewrite_training_op,
-                  bool rewrite_grad_op);
+                  bool rewrite_grad_op, bool use_fusion);
 
   // Returns whether any batch norm ops were rewritten.
   const bool changed() const { return changed_; }
@@ -70,10 +70,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
  private:
   explicit BatchNormRewriterVisitor(HloComputation* computation,
                                     bool rewrite_training_op,
-                                    bool rewrite_grad_op)
+                                    bool rewrite_grad_op, bool use_fusion)
       : computation_(computation),
         rewrite_training_op_(rewrite_training_op),
-        rewrite_grad_op_(rewrite_grad_op) {}
+        rewrite_grad_op_(rewrite_grad_op),
+        use_fusion_(use_fusion) {}
 
   HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type,
                                              HloOpcode opcode) {
@@ -94,6 +95,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
 
   bool rewrite_training_op_;
   bool rewrite_grad_op_;
+  bool use_fusion_;
 
   // Whether rewrite has occurred.
   bool changed_ = false;
@@ -124,10 +126,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
 
 bool BatchNormRewriterVisitor::Run(HloComputation* computation,
                                    bool rewrite_training_op,
-                                   bool rewrite_grad_op) {
+                                   bool rewrite_grad_op, bool use_fusion) {
   BatchNormRewriterVisitor visitor(computation,
                                    /*rewrite_training_op=*/rewrite_training_op,
-                                   /*rewrite_grad_op=*/rewrite_grad_op);
+                                   /*rewrite_grad_op=*/rewrite_grad_op,
+                                   /*use_fusion=*/use_fusion);
   TF_CHECK_OK(computation->Accept(&visitor));
   return visitor.changed_;
 }
@@ -189,18 +192,20 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
       add_reduce_computation));
 
   // Fuse two parallel reduces together to improve performance.
-  auto tuple = computation_->AddInstruction(
-      HloInstruction::CreateTuple({sum, squared_sum}));
+  if (use_fusion_) {
+    auto tuple = computation_->AddInstruction(
+        HloInstruction::CreateTuple({sum, squared_sum}));
 
-  auto fused = computation_->CreateFusionInstruction(
-      {tuple, sum, squared_sum, operand_squared},
-      HloInstruction::FusionKind::kInput);
+    auto fused = computation_->CreateFusionInstruction(
+        {tuple, sum, squared_sum, operand_squared},
+        HloInstruction::FusionKind::kInput);
 
-  sum = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
+    sum = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
 
-  squared_sum = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+    squared_sum = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+  }
 
   // E[X].
   auto mean = computation_->AddInstruction(HloInstruction::CreateBinary(
@@ -352,17 +357,19 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
       feature_shape, grad_output, zero, dimensions_without_feature,
       add_reduce_computation));
 
-  auto tuple = computation_->AddInstruction(
-      HloInstruction::CreateTuple({grad_scale, grad_beta}));
+  if (use_fusion_) {
+    auto tuple = computation_->AddInstruction(
+        HloInstruction::CreateTuple({grad_scale, grad_beta}));
 
-  auto fused = computation_->CreateFusionInstruction(
-      {tuple, grad_scale, grad_beta}, HloInstruction::FusionKind::kInput);
+    auto fused = computation_->CreateFusionInstruction(
+        {tuple, grad_scale, grad_beta}, HloInstruction::FusionKind::kInput);
 
-  grad_scale = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
+    grad_scale = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
 
-  grad_beta = computation_->AddInstruction(
-      HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+    grad_beta = computation_->AddInstruction(
+        HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+  }
 
   TF_CHECK_OK(ReplaceWithNewInstruction(
       batch_norm,
@@ -385,7 +392,7 @@ StatusOr<bool> BatchNormRewriter::Run(HloModule* module) {
   }
   for (auto& comp : computations) {
     if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_,
-                                      rewrite_grad_op_)) {
+                                      rewrite_grad_op_, use_fusion_)) {
       changed = true;
     }
   }
@@ -28,10 +28,12 @@ namespace xla {
 // logic.
 class BatchNormRewriter : public HloPassInterface {
  public:
+  // When use_fusion is set, a multi-output fusion node is created.
   BatchNormRewriter(bool rewrite_training_op = false,
-                    bool rewrite_grad_op = false)
+                    bool rewrite_grad_op = false, bool use_fusion = true)
       : rewrite_training_op_(rewrite_training_op),
-        rewrite_grad_op_(rewrite_grad_op) {}
+        rewrite_grad_op_(rewrite_grad_op),
+        use_fusion_(use_fusion) {}
   ~BatchNormRewriter() = default;
   tensorflow::StringPiece name() const override { return "batchnorm_rewriter"; }
 
@@ -42,6 +44,7 @@ class BatchNormRewriter : public HloPassInterface {
  private:
   bool rewrite_training_op_;
   bool rewrite_grad_op_;
+  bool use_fusion_;
 };
 
 }  // namespace xla
@@ -53,6 +53,7 @@ cc_library(
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:algebraic_simplifier",
+        "//tensorflow/compiler/xla/service:batchnorm_rewriter",
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:buffer_liveness",
         "//tensorflow/compiler/xla/service:compiler",
@@ -42,6 +42,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/copy_insertion.h"
@@ -266,6 +267,10 @@ Status CpuCompiler::RunHloPasses(HloModule* module) {
   {
     auto& pass =
         pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
+    pass.AddPass<BatchNormRewriter>(
+        /*rewrite_training_op=*/true,
+        /*rewrite_grad_op=*/true,
+        /*use_fusion=*/false);
     pass.AddPass<AlgebraicSimplifier>(
         /*is_layout_sensitive=*/false,
         [](const Shape&, const Shape&) { return false; },
@@ -417,6 +417,7 @@ cc_library(
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/service:algebraic_simplifier",
+        "//tensorflow/compiler/xla/service:batchnorm_rewriter",
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:buffer_liveness",
         "//tensorflow/compiler/xla/service:compiler",
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/protobuf_util.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
@@ -130,6 +131,12 @@ tensorflow::Status OptimizeHloModule(HloModule* hlo_module,
   {
     auto& pass =
         pipeline.AddPass<HloPassFix<HloPassPipeline>>("simplification");
+    // TODO(b/62764704): Do not rewrite on GPU, use cuDNN's BatchNorm APIs
+    // instead.
+    pass.AddPass<BatchNormRewriter>(
+        /*rewrite_training_op=*/true,
+        /*rewrite_grad_op=*/true,
+        /*use_fusion=*/false);
     pass.AddPass<AlgebraicSimplifier>(
         /*is_layout_sensitive=*/false,
         [](const Shape&, const Shape&) { return false; });
@@ -210,8 +210,7 @@ class BatchNormTest : public ClientLibraryTestBase,
                       public ::testing::WithParamInterface<BatchNormTestParam> {
 };
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) {
+XLA_TEST_P(BatchNormTest, RandomizedTests) {
   float epsilon = 0.001;
   ComputationBuilder builder(client_, TestName());
   const std::vector<int64>& bounds = GetParam().bounds;
@@ -264,15 +263,15 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) {
     var[i] = square_mean[i] - mean_square[i];
   }
 
-  Array4D<float> mean_4D =
+  Array4D<float> mean4D =
       *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
-  auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
-  auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
-  auto offset_4D =
+  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+  auto offset4D =
       *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);
 
-  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean_4D, var_4D,
-                                                scale_4D, offset_4D, epsilon);
+  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
+                                                scale4D, offset4D, epsilon);
 
   auto expected_normalized = Literal::CreateR4FromArray4D<float>(normalized);
 
@@ -307,9 +306,7 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_GPU(RandomizedTests)) {
                          ErrorSpec(0.01, 1));
 }
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_P(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
-               DISABLED_ON_GPU(RandomizedGradTests)))) {
+XLA_TEST_P(BatchNormTest, RandomizedGradTests) {
   float epsilon = 0.001;
   ComputationBuilder builder(client_, TestName());
   const std::vector<int64>& bounds = GetParam().bounds;
@@ -365,23 +362,23 @@ XLA_TEST_P(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
     var[i] = square_mean[i] - mean_square[i];
   }
 
-  Array4D<float> mean_4D =
+  Array4D<float> mean4D =
       *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
-  auto var_4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
-  auto scale_4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
+  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
+  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
 
   auto var_add_epsilon = *ReferenceUtil::MapArray4D(
-      var_4D, [epsilon](float a) { return std::sqrt(a + epsilon); });
+      var4D, [epsilon](float a) { return std::sqrt(a + epsilon); });
 
   auto grad_output_times_var =
       *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
                                  [](float a, float b) { return a * b; });
 
   auto grad_activation = *ReferenceUtil::MapArray4D(
-      grad_output_times_var, scale_4D, [](float a, float b) { return a * b; });
+      grad_output_times_var, scale4D, [](float a, float b) { return a * b; });
 
   auto activation_shifted = *ReferenceUtil::MapArray4D(
-      input_array, mean_4D, [](float a, float b) { return a - b; });
+      input_array, mean4D, [](float a, float b) { return a - b; });
 
   auto grad_scale_before_reduction =
       *ReferenceUtil::MapArray4D(grad_output_times_var, activation_shifted,
@@ -460,8 +457,7 @@ INSTANTIATE_TEST_CASE_P(
         // is correct after relayout.
         BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100}));
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTraining)) {
+XLA_TEST_F(BatchNormTest, BasicTraining) {
   const int kFeatureIndex = 3;
   ComputationBuilder builder(client_, TestName());
 
@@ -485,8 +481,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTraining)) {
   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTrainingOnSublane)) {
+XLA_TEST_F(BatchNormTest, BasicTrainingOnSublane) {
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());
 
@@ -510,7 +505,6 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(BasicTrainingOnSublane)) {
   ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
 XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) {
   // Use 0 dimension as feature, tests layout analyzer.
   const int kFeatureIndex = 0;
@@ -543,8 +537,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) {
                          ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on GPU. Disabled on 2017-06-20.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(LargeEpsilonTest)) {
+XLA_TEST_F(BatchNormTest, LargeEpsilonTest) {
   // Test the correctness of choosing a large epsilon value.
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());
@@ -577,9 +570,7 @@ XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(LargeEpsilonTest)) {
                          ErrorSpec(0.1));
 }
 
-// TODO(b/62764704): Implement on CPU and GPU. Disabled on 2017-07-11.
-XLA_TEST_F(BatchNormTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(
-               DISABLED_ON_GPU(BatchNormGradBasic)))) {
+XLA_TEST_F(BatchNormTest, BatchNormGradBasic) {
   const int kFeatureIndex = 2;
   ComputationBuilder builder(client_, TestName());
 