From 45dbb0a02d2fa5b1eb20836fe854e19842b9593f Mon Sep 17 00:00:00 2001 From: Blake Hechtman Date: Fri, 24 Mar 2017 12:17:38 -0800 Subject: [PATCH] Strength reduce Dot into broadcasting multiply and reduce. Also optimizes transposes and reshapes that feed reductions. Change: 151162327 --- .../xla/service/algebraic_simplifier.cc | 223 +++++++++++++++++- .../compiler/xla/service/hlo_computation.h | 1 + .../compiler/xla/tests/dot_operation_test.cc | 9 + tensorflow/compiler/xla/tests/reduce_test.cc | 66 ++++++ 4 files changed, 298 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 9fe2d0e6b61..d171a8dfffe 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -76,6 +76,24 @@ bool ReshapeIsBitcast( return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) && valid_bitcast_callback(operand->shape(), reshape->shape()); } + +// Adds a scalar computation to the module to enable optimizations with dot +// converting into reduction. 
+HloComputation* CreateScalarBinaryComputation(HloModule* module, + PrimitiveType primitive_type, + HloOpcode opcode) { + HloComputation::Builder b("scalar computation"); + auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "scalar lhs")); + auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {}), "scalar rhs")); + auto scalar_op = b.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), + opcode, scalar_lhs, scalar_rhs)); + HloComputation* scalar_computation = + module->AddEmbeddedComputation(b.Build(scalar_op)); + return scalar_computation; +} } // namespace // AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain @@ -105,6 +123,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, HloInstruction* rhs) override; + Status HandleDot(HloInstruction* dot, HloInstruction* lhs, + HloInstruction* rhs) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) override; @@ -304,6 +325,140 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, + HloInstruction* lhs, + HloInstruction* rhs) { + // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or + // below. + if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || + ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { + return Status::OK(); + } + + // Replace a zero element dot with a broadcast of the constant 0. 
+ if (ShapeUtil::HasZeroElements(dot->shape()) || + ShapeUtil::HasZeroElements(lhs->shape()) || + ShapeUtil::HasZeroElements(rhs->shape())) { + auto zero = computation_->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + changed_ = true; + return computation_->ReplaceWithNewInstruction( + dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); + } + + // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). + if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { + auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot, + rhs->mutable_operand(0), lhs->mutable_operand(0))); + changed_ = true; + return computation_->ReplaceWithNewInstruction( + dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); + } + + // Simplify outer product into multiply with implicit broadcasting. + // + // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) + if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) { + changed_ = true; + return computation_->ReplaceWithNewInstruction( + dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, + lhs, rhs)); + } + + // The following graph transformations take Dots where at least one input is a + // vector or has a degenerate dimension and converts it into a multiply and + // reduce. This should enable more fusion than leaving the nodes as Dot + // operations. 
+ + // Strength reduce dot(a[K] , b[K]) = + // reshape(result.shape, + // reduce_sum(multiply(a, b), {0})) + if (ShapeUtil::Rank(rhs->shape()) == 1 && + ShapeUtil::Rank(lhs->shape()) == 1) { + auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( + rhs->shape(), HloOpcode::kMultiply, lhs, rhs)); + HloComputation* add_reduce_computation = CreateScalarBinaryComputation( + computation_->parent(), F32, HloOpcode::kAdd); + auto zero = computation_->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, + {0}, add_reduce_computation)); + changed_ = true; + return computation_->ReplaceWithNewInstruction( + dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + } + + // Strength reduce dot(a[1, K], b) = + // reshape(result.shape, + // reduce_sum( + // multiply(broadcast(reshape(a, [K]), {0}), b), + // {0}) + // ) + // ) + if (ShapeUtil::Rank(lhs->shape()) == 1 || + (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) { + auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(lhs->shape().element_type(), + {ShapeUtil::ElementsIn(lhs->shape())}), + lhs)); + HloComputation* add_reduce_computation = CreateScalarBinaryComputation( + computation_->parent(), F32, HloOpcode::kAdd); + auto zero = computation_->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction* reduce; + if (ShapeUtil::Rank(rhs->shape()) == 1) { + auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( + rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); + reduce = computation_->AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, + {0}, add_reduce_computation)); + } else { + new_lhs = computation_->AddInstruction( + 
HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0})); + auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( + rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); + + reduce = computation_->AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(dot->shape().element_type(), + {rhs->shape().dimensions(1)}), + multiply, zero, {0}, add_reduce_computation)); + } + changed_ = true; + return computation_->ReplaceWithNewInstruction( + dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + } + + // Strength reduce dot(a, b[K, 1]) = + // reshape(result.shape, + // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) + // ) + if (ShapeUtil::Rank(rhs->shape()) == 1 || + (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) { + auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(rhs->shape().element_type(), + {ShapeUtil::ElementsIn(rhs->shape())}), + rhs)); + new_rhs = computation_->AddInstruction( + HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1})); + auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( + lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs)); + HloComputation* add_reduce_computation = CreateScalarBinaryComputation( + computation_->parent(), F32, HloOpcode::kAdd); + auto zero = computation_->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(dot->shape().element_type(), + {lhs->shape().dimensions(0)}), + multiply, zero, {1}, add_reduce_computation)); + changed_ = true; + return computation_->ReplaceWithNewInstruction( + dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, HloInstruction* rhs) { @@ -858,8 +1013,74 @@ Status 
AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice, Status AlgebraicSimplifierVisitor::HandleReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { + if (ShapeUtil::HasZeroElements(arg->shape()) || + ShapeUtil::HasZeroElements(reduce->shape())) { + return computation_->ReplaceWithNewInstruction( + reduce, + HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); + return Status::OK(); + } + // A Transpose feeding a reduce can simply permute the reduction dimensions + // field. + if (arg->opcode() == HloOpcode::kTranspose) { + auto transpose_dimensions = arg->dimensions(); + std::vector new_reduce_dimensions; + for (auto dim : dimensions) { + new_reduce_dimensions.push_back(transpose_dimensions[dim]); + } + return computation_->ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReduce( + reduce->shape(), arg->mutable_operand(0), init_value, + new_reduce_dimensions, function)); + } + + // A reshape that collapses multiple dimensions into a dimension being reduced + // can just reduce all of those dimensions instead of doing a collapsing + // reshape before a reduction. + if (arg->opcode() == HloOpcode::kReshape) { + std::vector> unmodified_dims = + ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(), + arg->shape()); + std::vector arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true); + std::vector arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false); + for (auto dim : dimensions) { + arg_dim_in_output[dim] = false; + } + for (auto dim_pair : unmodified_dims) { + arg_dim_unmodified[dim_pair.second] = true; + } + // The goal is to verify that all dimensions that are not removed in the + // reduce are unmodified by the reshape. 
For example: + // reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2]) + bool can_move_reshape_into_reduce = true; + for (int64 i = 0; i < arg_dim_in_output.size(); ++i) { + if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) { + can_move_reshape_into_reduce = false; + } + } + if (can_move_reshape_into_reduce) { + changed_ = true; + std::unordered_set dimensions_not_to_reduce; + for (auto dim_pair : unmodified_dims) { + if (arg_dim_in_output[dim_pair.second]) { + dimensions_not_to_reduce.insert(dim_pair.first); + } + } + std::vector new_reduce_dimensions; + for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) { + if (dimensions_not_to_reduce.count(i) == 0) { + new_reduce_dimensions.push_back(i); + } + } + return computation_->ReplaceWithNewInstruction( + reduce, HloInstruction::CreateReduce( + reduce->shape(), arg->mutable_operand(0), init_value, + new_reduce_dimensions, function)); + } + } if (ShapeUtil::ElementsIn(reduce->shape()) == - ShapeUtil::ElementsIn(arg->shape())) { + ShapeUtil::ElementsIn(arg->shape()) || + ShapeUtil::HasZeroElements(arg->shape())) { auto reshape = computation_->AddInstruction( HloInstruction::CreateReshape(reduce->shape(), arg)); changed_ = true; diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index fc02b2c4ef5..08cde9d0321 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -194,6 +194,7 @@ class HloComputation { // Set/get the module containing this computation. void set_parent(HloModule* module) { parent_ = module; } const HloModule* parent() const { return parent_; } + HloModule* parent() { return parent_; } // Visit every node in the computation in DFS post-order with the given // visitor. 
This is similar to calling HloInstruction::Accept on the root of diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 45df8114534..64d2af3c393 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -67,6 +67,15 @@ XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) { ComputeAndCompareR0(&builder, 0.0, {}, error_spec_); } +XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR2({{3.0, 4.0}}); + auto rhs = builder.ConstantR1({3.0, 4.0}); + auto result = builder.Dot(lhs, rhs); + + ComputeAndCompareR1(&builder, {25.0}, {}, error_spec_); +} + template void DotOperationTest::TestOneElementVectorDot() { ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 3d61b624dc8..34fce21758b 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -320,6 +320,72 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { ErrorSpec(0.01, 1e-4)); } +XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { + const int64 rows = 111, cols = 50; + + ComputationBuilder builder(client_, TestName()); + Computation add_f32 = CreateScalarAddComputation(F32, &builder); + const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols}); + auto input = builder.Parameter(0, input_shape, "input"); + auto zero = builder.ConstantR0(0.0); + auto log_ = builder.Log(input); + auto transpose = builder.Transpose(log_, {1, 0}); + builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1}); + + Array2D input_data(rows, cols); + input_data.FillRandom(3.14f, 0.04); + std::unique_ptr input_literal = + LiteralUtil::CreateR2FromArray2D(input_data); + input_literal = + 
LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1})); + std::unique_ptr input_global_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + std::vector expected; + for (int64 colno = 0; colno < cols; ++colno) { + float column_sum = 0; + for (int64 rowno = 0; rowno < rows; ++rowno) { + column_sum += log(input_data(rowno, colno)); + } + expected.push_back(column_sum); + } + ComputeAndCompareR1(&builder, expected, {input_global_data.get()}, + ErrorSpec(0.01, 1e-4)); +} + +XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { + const int64 rows = 111, cols = 50; + + ComputationBuilder builder(client_, TestName()); + Computation add_f32 = CreateScalarAddComputation(F32, &builder); + const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto zero = builder.ConstantR0(0.0); + auto log_ = builder.Log(input); + auto reshape = builder.Reshape(log_, {rows, cols}); + builder.Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0}); + + Array3D input_data(rows, 2, cols / 2); + input_data.FillRandom(3.14f, 0.04); + std::unique_ptr input_literal = + LiteralUtil::CreateR3FromArray3D(input_data); + std::unique_ptr input_global_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + std::vector expected; + for (int64 major = 0; major < 2; ++major) { + for (int64 colno = 0; colno < cols / 2; ++colno) { + float column_sum = 0; + for (int64 rowno = 0; rowno < rows; ++rowno) { + column_sum += log(input_data(rowno, major, colno)); + } + expected.push_back(column_sum); + } + } + ComputeAndCompareR1(&builder, expected, {input_global_data.get()}, + ErrorSpec(0.01, 1e-4)); +} + struct BoundsLayout { std::vector bounds; std::vector layout;