Implement Mul(Convert(Pred), operand) => select(pred, operand, 0) optimization.
PiperOrigin-RevId: 316739811 Change-Id: Ica5e50c6639a9792ae1dd47eefd713021fb97533
This commit is contained in:
parent
e0266dbf39
commit
ed1d7d09ae
@ -2455,6 +2455,25 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
{
|
||||
HloInstruction *convert_operand, *operand;
|
||||
// Mul(Convert(Pred), operand) => select(pred, operand, 0)
|
||||
if (Match(multiply,
|
||||
m::MultiplyAnyOrder(
|
||||
m::Op(&operand),
|
||||
m::Convert(
|
||||
m::Op(&convert_operand)
|
||||
.WithShape(m::Shape().WithElementType(PRED)))))) {
|
||||
HloInstruction* zero_like_multiply =
|
||||
BroadcastZeros(computation_, multiply->shape().element_type(),
|
||||
multiply->shape().dimensions());
|
||||
return ReplaceWithNewInstruction(
|
||||
multiply, HloInstruction::CreateTernary(
|
||||
multiply->shape(), HloOpcode::kSelect, convert_operand,
|
||||
operand, zero_like_multiply));
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
|
||||
HloInstruction *a, *c1, *c2;
|
||||
if (Match(multiply,
|
||||
|
@ -539,6 +539,15 @@ HloInstruction* BroadcastZeros(HloComputation* computation,
|
||||
/*result_shape_bounds=*/broadcast_dimensions);
|
||||
}
|
||||
|
||||
HloInstruction* BroadcastOnes(HloComputation* computation,
|
||||
PrimitiveType element_type,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
HloInstruction* one = computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::One(element_type)));
|
||||
return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{},
|
||||
/*result_shape_bounds=*/broadcast_dimensions);
|
||||
}
|
||||
|
||||
// Recursively creates a dummy op given a shape. Leaf nodes are broadcasted zero
|
||||
// while internal nodes are tuples.
|
||||
HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
|
||||
|
@ -276,6 +276,11 @@ HloInstruction* BroadcastZeros(HloComputation* computation,
|
||||
PrimitiveType element_type,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
|
||||
// Same as above, but fill the tensor with ones.
|
||||
HloInstruction* BroadcastOnes(HloComputation* computation,
|
||||
PrimitiveType element_type,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
|
||||
// Creates a HLO computation that takes arguments of type `domain` and produces
|
||||
// a value of type `range`.
|
||||
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
|
||||
|
Loading…
Reference in New Issue
Block a user