Strength reduce Dot into broadcasting multiply and reduce. Also optimizes
transposes and reshapes that feed reductions. Change: 151162327
This commit is contained in:
parent
fdf32bc5ed
commit
45dbb0a02d
tensorflow/compiler/xla
@ -76,6 +76,24 @@ bool ReshapeIsBitcast(
|
||||
return ShapeUtil::ReshapeIsBitcast(operand->shape(), reshape->shape()) &&
|
||||
valid_bitcast_callback(operand->shape(), reshape->shape());
|
||||
}
|
||||
|
||||
// Adds a scalar computation to the module to enable optimizations with dot
|
||||
// converting into reduction.
|
||||
HloComputation* CreateScalarBinaryComputation(HloModule* module,
|
||||
PrimitiveType primitive_type,
|
||||
HloOpcode opcode) {
|
||||
HloComputation::Builder b("scalar computation");
|
||||
auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
|
||||
0, ShapeUtil::MakeShape(F32, {}), "scalar lhs"));
|
||||
auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
|
||||
1, ShapeUtil::MakeShape(F32, {}), "scalar rhs"));
|
||||
auto scalar_op = b.AddInstruction(
|
||||
HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
|
||||
opcode, scalar_lhs, scalar_rhs));
|
||||
HloComputation* scalar_computation =
|
||||
module->AddEmbeddedComputation(b.Build(scalar_op));
|
||||
return scalar_computation;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// AlgebraicSimplifierVisitor traverses the HLO computation and reduces certain
|
||||
@ -105,6 +123,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {
|
||||
Status HandleDivide(HloInstruction* divide, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override;
|
||||
|
||||
Status HandleDot(HloInstruction* dot, HloInstruction* lhs,
|
||||
HloInstruction* rhs) override;
|
||||
|
||||
Status HandleGetTupleElement(HloInstruction* get_tuple_element,
|
||||
HloInstruction* operand) override;
|
||||
|
||||
@ -304,6 +325,140 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot,
|
||||
HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
// Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or
|
||||
// below.
|
||||
if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 ||
|
||||
ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Replace a zero element dot with a broadcast of the constant 0.
|
||||
if (ShapeUtil::HasZeroElements(dot->shape()) ||
|
||||
ShapeUtil::HasZeroElements(lhs->shape()) ||
|
||||
ShapeUtil::HasZeroElements(rhs->shape())) {
|
||||
auto zero = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
|
||||
changed_ = true;
|
||||
return computation_->ReplaceWithNewInstruction(
|
||||
dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {}));
|
||||
}
|
||||
|
||||
// Simplify dot(transpose(a), transpose(b)) to tranpose(dot(b,a)).
|
||||
if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
|
||||
auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot,
|
||||
rhs->mutable_operand(0), lhs->mutable_operand(0)));
|
||||
changed_ = true;
|
||||
return computation_->ReplaceWithNewInstruction(
|
||||
dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
|
||||
}
|
||||
|
||||
// Simplify outer product into multiply with implicit broadcasting.
|
||||
//
|
||||
// A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N])
|
||||
if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) {
|
||||
changed_ = true;
|
||||
return computation_->ReplaceWithNewInstruction(
|
||||
dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply,
|
||||
lhs, rhs));
|
||||
}
|
||||
|
||||
// The following graph transformations take Dots where at least one input is a
|
||||
// vector or has a degenerate dimension and converts it into a multiply and
|
||||
// reduce. This should enable more fusion than leaving the nodes as Dot
|
||||
// operations.
|
||||
|
||||
// Strength reduce dot(a[K] , b[K]) =
|
||||
// reshape(result.shape,
|
||||
// reduce_sum(multiply(a, b), {0}))
|
||||
if (ShapeUtil::Rank(rhs->shape()) == 1 &&
|
||||
ShapeUtil::Rank(lhs->shape()) == 1) {
|
||||
auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
rhs->shape(), HloOpcode::kMultiply, lhs, rhs));
|
||||
HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
|
||||
computation_->parent(), F32, HloOpcode::kAdd);
|
||||
auto zero = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
|
||||
auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
|
||||
ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
|
||||
{0}, add_reduce_computation));
|
||||
changed_ = true;
|
||||
return computation_->ReplaceWithNewInstruction(
|
||||
dot, HloInstruction::CreateReshape(dot->shape(), reduce));
|
||||
}
|
||||
|
||||
// Strength reduce dot(a[1, K], b) =
|
||||
// reshape(result.shape,
|
||||
// reduce_sum(
|
||||
// multiply(broadcast(reshape(a, [K]), {0}), b),
|
||||
// {0})
|
||||
// )
|
||||
// )
|
||||
if (ShapeUtil::Rank(lhs->shape()) == 1 ||
|
||||
(ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) {
|
||||
auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(lhs->shape().element_type(),
|
||||
{ShapeUtil::ElementsIn(lhs->shape())}),
|
||||
lhs));
|
||||
HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
|
||||
computation_->parent(), F32, HloOpcode::kAdd);
|
||||
auto zero = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
|
||||
HloInstruction* reduce;
|
||||
if (ShapeUtil::Rank(rhs->shape()) == 1) {
|
||||
auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs));
|
||||
reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
|
||||
ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero,
|
||||
{0}, add_reduce_computation));
|
||||
} else {
|
||||
new_lhs = computation_->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0}));
|
||||
auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs));
|
||||
|
||||
reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
|
||||
ShapeUtil::MakeShape(dot->shape().element_type(),
|
||||
{rhs->shape().dimensions(1)}),
|
||||
multiply, zero, {0}, add_reduce_computation));
|
||||
}
|
||||
changed_ = true;
|
||||
return computation_->ReplaceWithNewInstruction(
|
||||
dot, HloInstruction::CreateReshape(dot->shape(), reduce));
|
||||
}
|
||||
|
||||
// Strength reduce dot(a, b[K, 1]) =
|
||||
// reshape(result.shape,
|
||||
// reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0})
|
||||
// )
|
||||
if (ShapeUtil::Rank(rhs->shape()) == 1 ||
|
||||
(ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) {
|
||||
auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(rhs->shape().element_type(),
|
||||
{ShapeUtil::ElementsIn(rhs->shape())}),
|
||||
rhs));
|
||||
new_rhs = computation_->AddInstruction(
|
||||
HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1}));
|
||||
auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary(
|
||||
lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs));
|
||||
HloComputation* add_reduce_computation = CreateScalarBinaryComputation(
|
||||
computation_->parent(), F32, HloOpcode::kAdd);
|
||||
auto zero = computation_->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)));
|
||||
auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce(
|
||||
ShapeUtil::MakeShape(dot->shape().element_type(),
|
||||
{lhs->shape().dimensions(0)}),
|
||||
multiply, zero, {1}, add_reduce_computation));
|
||||
changed_ = true;
|
||||
return computation_->ReplaceWithNewInstruction(
|
||||
dot, HloInstruction::CreateReshape(dot->shape(), reduce));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply,
|
||||
HloInstruction* lhs,
|
||||
HloInstruction* rhs) {
|
||||
@ -858,8 +1013,74 @@ Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice,
|
||||
Status AlgebraicSimplifierVisitor::HandleReduce(
|
||||
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
|
||||
tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function) {
|
||||
if (ShapeUtil::HasZeroElements(arg->shape()) ||
|
||||
ShapeUtil::HasZeroElements(reduce->shape())) {
|
||||
return computation_->ReplaceWithNewInstruction(
|
||||
reduce,
|
||||
HloInstruction::CreateBroadcast(reduce->shape(), init_value, {}));
|
||||
return Status::OK();
|
||||
}
|
||||
// A Transpose feeding a reduce can simply permute the reduction dimensions
|
||||
// field.
|
||||
if (arg->opcode() == HloOpcode::kTranspose) {
|
||||
auto transpose_dimensions = arg->dimensions();
|
||||
std::vector<int64> new_reduce_dimensions;
|
||||
for (auto dim : dimensions) {
|
||||
new_reduce_dimensions.push_back(transpose_dimensions[dim]);
|
||||
}
|
||||
return computation_->ReplaceWithNewInstruction(
|
||||
reduce, HloInstruction::CreateReduce(
|
||||
reduce->shape(), arg->mutable_operand(0), init_value,
|
||||
new_reduce_dimensions, function));
|
||||
}
|
||||
|
||||
// A reshape that collapses multiple dimensions into a dimension being reduced
|
||||
// can just reduce all of those dimensions instead of doing a collapsing
|
||||
// reshape before a reduction.
|
||||
if (arg->opcode() == HloOpcode::kReshape) {
|
||||
std::vector<std::pair<int64, int64>> unmodified_dims =
|
||||
ShapeUtil::DimensionsUnmodifiedByReshape(arg->operand(0)->shape(),
|
||||
arg->shape());
|
||||
std::vector<bool> arg_dim_in_output(ShapeUtil::Rank(arg->shape()), true);
|
||||
std::vector<bool> arg_dim_unmodified(ShapeUtil::Rank(arg->shape()), false);
|
||||
for (auto dim : dimensions) {
|
||||
arg_dim_in_output[dim] = false;
|
||||
}
|
||||
for (auto dim_pair : unmodified_dims) {
|
||||
arg_dim_unmodified[dim_pair.second] = true;
|
||||
}
|
||||
// The goal is to verify that all dimensions that are not removed in the
|
||||
// reduce are unmodified by the reshape. For example:
|
||||
// reduce(reshape([A,B*C], a[A,B,C]),[1]) = reduce(a[A, B, C], [1, 2])
|
||||
bool can_move_reshape_into_reduce = true;
|
||||
for (int64 i = 0; i < arg_dim_in_output.size(); ++i) {
|
||||
if (arg_dim_in_output[i] && !arg_dim_unmodified[i]) {
|
||||
can_move_reshape_into_reduce = false;
|
||||
}
|
||||
}
|
||||
if (can_move_reshape_into_reduce) {
|
||||
changed_ = true;
|
||||
std::unordered_set<int64> dimensions_not_to_reduce;
|
||||
for (auto dim_pair : unmodified_dims) {
|
||||
if (arg_dim_in_output[dim_pair.second]) {
|
||||
dimensions_not_to_reduce.insert(dim_pair.first);
|
||||
}
|
||||
}
|
||||
std::vector<int64> new_reduce_dimensions;
|
||||
for (int64 i = 0; i < ShapeUtil::Rank(arg->operand(0)->shape()); ++i) {
|
||||
if (dimensions_not_to_reduce.count(i) == 0) {
|
||||
new_reduce_dimensions.push_back(i);
|
||||
}
|
||||
}
|
||||
return computation_->ReplaceWithNewInstruction(
|
||||
reduce, HloInstruction::CreateReduce(
|
||||
reduce->shape(), arg->mutable_operand(0), init_value,
|
||||
new_reduce_dimensions, function));
|
||||
}
|
||||
}
|
||||
if (ShapeUtil::ElementsIn(reduce->shape()) ==
|
||||
ShapeUtil::ElementsIn(arg->shape())) {
|
||||
ShapeUtil::ElementsIn(arg->shape()) ||
|
||||
ShapeUtil::HasZeroElements(arg->shape())) {
|
||||
auto reshape = computation_->AddInstruction(
|
||||
HloInstruction::CreateReshape(reduce->shape(), arg));
|
||||
changed_ = true;
|
||||
|
@ -194,6 +194,7 @@ class HloComputation {
|
||||
// Set/get the module containing this computation.
|
||||
void set_parent(HloModule* module) { parent_ = module; }
|
||||
const HloModule* parent() const { return parent_; }
|
||||
HloModule* parent() { return parent_; }
|
||||
|
||||
// Visit every node in the computation in DFS post-order with the given
|
||||
// visitor. This is similar to calling HloInstruction::Accept on the root of
|
||||
|
@ -67,6 +67,15 @@ XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) {
|
||||
ComputeAndCompareR0<float>(&builder, 0.0, {}, error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(DotOperationTest, TrivialMatrixVectorDotF32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto lhs = builder.ConstantR2<float>({{3.0, 4.0}});
|
||||
auto rhs = builder.ConstantR1<float>({3.0, 4.0});
|
||||
auto result = builder.Dot(lhs, rhs);
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, {25.0}, {}, error_spec_);
|
||||
}
|
||||
|
||||
template <typename Element>
|
||||
void DotOperationTest::TestOneElementVectorDot() {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
@ -320,6 +320,72 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
|
||||
ErrorSpec(0.01, 1e-4));
|
||||
}
|
||||
|
||||
XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
|
||||
const int64 rows = 111, cols = 50;
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Computation add_f32 = CreateScalarAddComputation(F32, &builder);
|
||||
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto zero = builder.ConstantR0<float>(0.0);
|
||||
auto log_ = builder.Log(input);
|
||||
auto transpose = builder.Transpose(log_, {1, 0});
|
||||
builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{1});
|
||||
|
||||
Array2D<float> input_data(rows, cols);
|
||||
input_data.FillRandom(3.14f, 0.04);
|
||||
std::unique_ptr<Literal> input_literal =
|
||||
LiteralUtil::CreateR2FromArray2D(input_data);
|
||||
input_literal =
|
||||
LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1}));
|
||||
std::unique_ptr<GlobalData> input_global_data =
|
||||
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
|
||||
|
||||
std::vector<float> expected;
|
||||
for (int64 colno = 0; colno < cols; ++colno) {
|
||||
float column_sum = 0;
|
||||
for (int64 rowno = 0; rowno < rows; ++rowno) {
|
||||
column_sum += log(input_data(rowno, colno));
|
||||
}
|
||||
expected.push_back(column_sum);
|
||||
}
|
||||
ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
|
||||
ErrorSpec(0.01, 1e-4));
|
||||
}
|
||||
|
||||
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
|
||||
const int64 rows = 111, cols = 50;
|
||||
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
Computation add_f32 = CreateScalarAddComputation(F32, &builder);
|
||||
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2});
|
||||
auto input = builder.Parameter(0, input_shape, "input");
|
||||
auto zero = builder.ConstantR0<float>(0.0);
|
||||
auto log_ = builder.Log(input);
|
||||
auto reshape = builder.Reshape(log_, {rows, cols});
|
||||
builder.Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0});
|
||||
|
||||
Array3D<float> input_data(rows, 2, cols / 2);
|
||||
input_data.FillRandom(3.14f, 0.04);
|
||||
std::unique_ptr<Literal> input_literal =
|
||||
LiteralUtil::CreateR3FromArray3D(input_data);
|
||||
std::unique_ptr<GlobalData> input_global_data =
|
||||
client_->TransferToServer(*input_literal).ConsumeValueOrDie();
|
||||
|
||||
std::vector<float> expected;
|
||||
for (int64 major = 0; major < 2; ++major) {
|
||||
for (int64 colno = 0; colno < cols / 2; ++colno) {
|
||||
float column_sum = 0;
|
||||
for (int64 rowno = 0; rowno < rows; ++rowno) {
|
||||
column_sum += log(input_data(rowno, major, colno));
|
||||
}
|
||||
expected.push_back(column_sum);
|
||||
}
|
||||
}
|
||||
ComputeAndCompareR1<float>(&builder, expected, {input_global_data.get()},
|
||||
ErrorSpec(0.01, 1e-4));
|
||||
}
|
||||
|
||||
struct BoundsLayout {
|
||||
std::vector<int64> bounds;
|
||||
std::vector<int64> layout;
|
||||
|
Loading…
Reference in New Issue
Block a user