Strength reduce Dot into broadcasting multiply and reduce. Also optimizes
transposes and reshapes that feed reductions.
Change: 151162327
Blake Hechtman 2017-03-24 12:17:38 -08:00 committed by TensorFlower Gardener
parent fdf32bc5ed
commit 45dbb0a02d
4 changed files with 298 additions and 1 deletion
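As a rough illustration of the new Dot strength reduction, here is a minimal sketch, not part of the commit, that builds the same value both as a direct Dot and in the rewritten multiply-plus-reduce form. It is written against the ComputationBuilder style used in the test changes below; client_ and CreateScalarAddComputation are borrowed from those tests and assumed to be in scope.

  ComputationBuilder builder(client_, "dot_strength_reduction_sketch");
  Computation add_f32 = CreateScalarAddComputation(F32, &builder);
  auto a = builder.ConstantR1<float>({1.0, 2.0, 3.0});
  auto b = builder.ConstantR1<float>({4.0, 5.0, 6.0});
  // Direct form: dot(a, b) = 1*4 + 2*5 + 3*6 = 32.
  auto dot = builder.Dot(a, b);
  // Strength-reduced form: the elementwise multiply gives [4, 10, 18] and the
  // reduce-sum over dimension 0 gives the same 32.
  auto zero = builder.ConstantR0<float>(0.0);
  auto reduced = builder.Reduce(builder.Mul(a, b), zero, add_f32,
                                /*dimensions_to_reduce=*/{0});

The HandleDot change below performs this rewrite directly on the HLO graph, so the reduced form never has to be written out by hand.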


@@ -76,6 +76,24 @@ bool ReshapeIsBitcast(
return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) &&
valid_bitcast_callback(operand->shape(), reshape->shape());
}
// Adds a scalar computation to the module to enable optimizations that convert
// a dot into a reduction.
HloComputation* CreateScalarBinaryComputation(HloModule* module,
PrimitiveType primitive_type,
HloOpcode opcode) {
HloComputation::Builder b("scalar computation");
auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
    0, ShapeUtil::MakeShape(primitive_type, {}), "scalar lhs"));
auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
    1, ShapeUtil::MakeShape(primitive_type, {}), "scalar rhs"));
auto scalar_op = b.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
opcode, scalar_lhs, scalar_rhs));
HloComputation* scalar_computation =
module->AddEmbeddedComputation(b.Build(scalar_op));
return scalar_computation;
}
} // namespace
// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
@@ -105,6 +123,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
HloInstruction* rhs) override;
Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
HloInstruction* rhs) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element,
HloInstruction* operand) override;
@@ -304,6 +325,140 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide,
return Status::OK();
}
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
HloInstruction* lhs,
HloInstruction* rhs) {
// Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or
// below.
if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 ||
ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) {
return Status::OK();
}
// Replace a zero element dot with a broadcast of the constant 0.
if (ShapeUtil::HasZeroElements(dot->shape()) ||
ShapeUtil::HasZeroElements(lhs->shape()) ||
ShapeUtil::HasZeroElements(rhs->shape())) {
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
changed_ = true;
return computation_->ReplaceWithNewInstruction(
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
}
// Simplify dot(transpose(a), transpose(b)) to transpose(dot(b, a)).
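// This relies on the matrix identity transpose(a) * transpose(b) =
// transpose(b * a).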
if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary(
ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot,
rhs->mutable_operand(0), lhs->mutable_operand(0)));
changed_ = true;
return computation_->ReplaceWithNewInstruction(
dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
}
// Simplify outer product into multiply with implicit broadcasting.
//
//   dot(a[M, 1], b[1, N]) = multiply(a[M, 1], b[1, N])
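// For example, dot([[1], [2]], [[3, 4]]) = [[3, 4], [6, 8]], which is exactly
// the elementwise product of the implicitly broadcast operands.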
if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) {
changed_ = true;
return computation_->ReplaceWithNewInstruction(
dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
lhs, rhs));
}
// The following graph transformations take Dots where at least one input is a
// vector or has a degenerate dimension and convert them into a multiply and
// reduce. This should enable more fusion than leaving the nodes as Dot
// operations.
// Strength reduce dot(a[K], b[K]) =
//   reshape(result.shape,
//           reduce_sum(multiply(a, b), {0}))
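// For example, with a = [1, 2, 3] and b = [4, 5, 6], multiply(a, b) is
// [4, 10, 18] and the reduce over {0} yields 32, matching dot(a, b).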
if (ShapeUtil::Rank(rhs->shape()) == 1 &&
ShapeUtil::Rank(lhs->shape()) == 1) {
auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
rhs->shape(), HloOpcode::kMultiply, lhs, rhs));
HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
computation_->parent(), F32, HloOpcode::kAdd);
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
{0}, add_reduce_computation));
changed_ = true;
return computation_->ReplaceWithNewInstruction(
dot, HloInstruction::CreateReshape(dot->shape(), reduce));
}
// Strength reduce dot(a[1, K], b) =
//   reshape(result.shape,
//           reduce_sum(multiply(broadcast(reshape(a, [K]), {0}), b), {0}))
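// For example, dot([[1, 2]], [[3, 4], [5, 6]]) = [[13, 16]]: the broadcast
// multiply produces [[3, 4], [10, 12]] and the reduce over {0} sums each
// column.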
if (ShapeUtil::Rank(lhs->shape()) == 1 ||
(ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) {
auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(lhs->shape().element_type(),
{ShapeUtil::ElementsIn(lhs->shape())}),
lhs));
HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
computation_->parent(), F32, HloOpcode::kAdd);
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
HloInstruction* reduce;
if (ShapeUtil::Rank(rhs->shape()) == 1) {
auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs));
reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
{0}, add_reduce_computation));
} else {
new_lhs = computation_->AddInstruction(
HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0}));
auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs));
reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(dot->shape().element_type(),
{rhs->shape().dimensions(1)}),
multiply, zero, {0}, add_reduce_computation));
}
changed_ = true;
return computation_->ReplaceWithNewInstruction(
dot, HloInstruction::CreateReshape(dot->shape(), reduce));
}
// Strength reduce dot(a, b[K, 1]) =
//   reshape(result.shape,
//           reduce_sum(multiply(a, broadcast(reshape(b, [K]), {1})), {1}))
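// For example, dot([[1, 2], [3, 4]], [[5], [6]]) = [[17], [39]]: the broadcast
// multiply produces [[5, 12], [15, 24]] and the reduce over {1} sums each row.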
if (ShapeUtil::Rank(rhs->shape()) == 1 ||
(ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) {
auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(rhs->shape().element_type(),
{ShapeUtil::ElementsIn(rhs->shape())}),
rhs));
new_rhs = computation_->AddInstruction(
HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1}));
auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs));
HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
computation_->parent(), F32, HloOpcode::kAdd);
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
ShapeUtil::MakeShape(dot->shape().element_type(),
{lhs->shape().dimensions(0)}),
multiply, zero, {1}, add_reduce_computation));
changed_ = true;
return computation_->ReplaceWithNewInstruction(
dot, HloInstruction::CreateReshape(dot->shape(), reduce));
}
return Status::OK();
}
Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply,
HloInstruction* lhs,
HloInstruction* rhs) {
@@ -858,8 +1013,74 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice,
Status AlgebraicSimplifierVisitor::HandleReduce(
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function) {
if (ShapeUtil::HasZeroElements(arg->shape()) ||
    ShapeUtil::HasZeroElements(reduce->shape())) {
  return computation_->ReplaceWithNewInstruction(
      reduce,
      HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
}
// A Transpose feeding a reduce can simply permute the reduction dimensions
// field.
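// For example, reduce(transpose(a[M, N], {1, 0}), {0}) is rewritten to
// reduce(a[M, N], {1}).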
if (arg->opcode() == HloOpcode::kTranspose) {
auto transpose_dimensions = arg->dimensions();
std::vector<int64> new_reduce_dimensions;
for (auto dim : dimensions) {
new_reduce_dimensions.push_back(transpose_dimensions[dim]);
}
return computation_->ReplaceWithNewInstruction(
reduce, HloInstruction::CreateReduce(
reduce->shape(), arg->mutable_operand(0), init_value,
new_reduce_dimensions, function));
}
// A reshape that collapses multiple dimensions into a dimension being reduced
// can just reduce all of those dimensions instead of doing a collapsing
// reshape before a reduction.
if (arg->opcode() == HloOpcode::kReshape) {
std::vector<std::pair<int64, int64>> unmodified_dims =
ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
arg->shape());
std::vector<bool> arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true);
std::vector<bool> arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false);
for (auto dim : dimensions) {
arg_dim_in_output[dim] = false;
}
for (auto dim_pair : unmodified_dims) {
arg_dim_unmodified[dim_pair.second] = true;
}
// The goal is to verify that all dimensions that are not removed in the
// reduce are unmodified by the reshape. For example:
// reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
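// Conversely, in reduce(reshape([A*B, C], a[A, B, C]), [1]) the kept dimension
// of size A*B is modified by the reshape, so the rewrite does not apply.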
bool can_move_reshape_into_reduce = true;
for (int64 i = 0; i < arg_dim_in_output.size(); ++i) {
if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
can_move_reshape_into_reduce = false;
}
}
if (can_move_reshape_into_reduce) {
changed_ = true;
std::unordered_set<int64> dimensions_not_to_reduce;
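// Keep (i.e. do not reduce) every operand dimension that maps unmodified to
// an output dimension that the reduce retains.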
for (auto dim_pair : unmodified_dims) {
if (arg_dim_in_output[dim_pair.second]) {
dimensions_not_to_reduce.insert(dim_pair.first);
}
}
std::vector<int64> new_reduce_dimensions;
for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) {
if (dimensions_not_to_reduce.count(i) == 0) {
new_reduce_dimensions.push_back(i);
}
}
return computation_->ReplaceWithNewInstruction(
reduce, HloInstruction::CreateReduce(
reduce->shape(), arg->mutable_operand(0), init_value,
new_reduce_dimensions, function));
}
}
if (ShapeUtil::ElementsIn(reduce->shape()) ==
        ShapeUtil::ElementsIn(arg->shape()) ||
    ShapeUtil::HasZeroElements(arg->shape())) {
auto reshape = computation_->AddInstruction(
HloInstruction::CreateReshape(reduce->shape(), arg));
changed_ = true;


@@ -194,6 +194,7 @@ class HloComputation {
// Set/get the module containing this computation.
void set_parent(HloModule* module) { parent_ = module; }
const HloModule* parent() const { return parent_; }
HloModule* parent() { return parent_; }
// Visit every node in the computation in DFS post-order with the given
// visitor. This is similar to calling HloInstruction::Accept on the root of


@@ -67,6 +67,15 @@ XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) {
ComputeAndCompareR0<float>(&builder, 0.0, {}, error_spec_);
}
XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) {
ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR2<float>({{3.0, 4.0}});
auto rhs = builder.ConstantR1<float>({3.0, 4.0});
auto result = builder.Dot(lhs, rhs);
ComputeAndCompareR1<float>(&builder, {25.0}, {}, error_spec_);
}
template <typename Element>
void DotOperationTest::TestOneElementVectorDot() {
ComputationBuilder builder(client_, TestName());


@@ -320,6 +320,72 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
ErrorSpec(0.01, 1e-4));
}
XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
const int64 rows = 111, cols = 50;
ComputationBuilder builder(client_, TestName());
Computation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
auto input = builder.Parameter(0, input_shape, "input");
auto zero = builder.ConstantR0<float>(0.0);
auto log_ = builder.Log(input);
auto transpose = builder.Transpose(log_, {1, 0});
builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1});
Array2D<float> input_data(rows, cols);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR2FromArray2D(input_data);
input_literal =
LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1}));
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 colno = 0; colno < cols; ++colno) {
float column_sum = 0;
for (int64 rowno = 0; rowno < rows; ++rowno) {
column_sum += log(input_data(rowno, colno));
}
expected.push_back(column_sum);
}
ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
ErrorSpec(0.01, 1e-4));
}
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
const int64 rows = 111, cols = 50;
ComputationBuilder builder(client_, TestName());
Computation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2});
auto input = builder.Parameter(0, input_shape, "input");
auto zero = builder.ConstantR0<float>(0.0);
auto log_ = builder.Log(input);
auto reshape = builder.Reshape(log_, {rows, cols});
builder.Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0});
Array3D<float> input_data(rows, 2, cols / 2);
input_data.FillRandom(3.14f, 0.04);
std::unique_ptr<Literal> input_literal =
LiteralUtil::CreateR3FromArray3D(input_data);
std::unique_ptr<GlobalData> input_global_data =
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
std::vector<float> expected;
for (int64 major = 0; major < 2; ++major) {
for (int64 colno = 0; colno < cols / 2; ++colno) {
float column_sum = 0;
for (int64 rowno = 0; rowno < rows; ++rowno) {
column_sum += log(input_data(rowno, major, colno));
}
expected.push_back(column_sum);
}
}
ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
ErrorSpec(0.01, 1e-4));
}
struct BoundsLayout {
std::vector<int64> bounds;
std::vector<int64> layout;