Implement Mul(Convert(Pred), operand) => select(pred, operand, 0) optimization.
PiperOrigin-RevId: 316739811 Change-Id: Ica5e50c6639a9792ae1dd47eefd713021fb97533
This commit is contained in:
parent
e0266dbf39
commit
ed1d7d09ae
@ -2455,6 +2455,25 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
HloInstruction *convert_operand, *operand;
|
||||||
|
// Mul(Convert(Pred), operand) => select(pred, operand, 0)
|
||||||
|
if (Match(multiply,
|
||||||
|
m::MultiplyAnyOrder(
|
||||||
|
m::Op(&operand),
|
||||||
|
m::Convert(
|
||||||
|
m::Op(&convert_operand)
|
||||||
|
.WithShape(m::Shape().WithElementType(PRED)))))) {
|
||||||
|
HloInstruction* zero_like_multiply =
|
||||||
|
BroadcastZeros(computation_, multiply->shape().element_type(),
|
||||||
|
multiply->shape().dimensions());
|
||||||
|
return ReplaceWithNewInstruction(
|
||||||
|
multiply, HloInstruction::CreateTernary(
|
||||||
|
multiply->shape(), HloOpcode::kSelect, convert_operand,
|
||||||
|
operand, zero_like_multiply));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
|
VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
|
||||||
HloInstruction *a, *c1, *c2;
|
HloInstruction *a, *c1, *c2;
|
||||||
if (Match(multiply,
|
if (Match(multiply,
|
||||||
|
@ -539,6 +539,15 @@ HloInstruction* BroadcastZeros(HloComputation* computation,
|
|||||||
/*result_shape_bounds=*/broadcast_dimensions);
|
/*result_shape_bounds=*/broadcast_dimensions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloInstruction* BroadcastOnes(HloComputation* computation,
|
||||||
|
PrimitiveType element_type,
|
||||||
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
|
HloInstruction* one = computation->AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(LiteralUtil::One(element_type)));
|
||||||
|
return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{},
|
||||||
|
/*result_shape_bounds=*/broadcast_dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
// Recursively creates a dummy op given a shape. Leaf nodes are broadcasted zero
|
// Recursively creates a dummy op given a shape. Leaf nodes are broadcasted zero
|
||||||
// while internal nodes are tuples.
|
// while internal nodes are tuples.
|
||||||
HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
|
HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
|
||||||
|
@ -276,6 +276,11 @@ HloInstruction* BroadcastZeros(HloComputation* computation,
|
|||||||
PrimitiveType element_type,
|
PrimitiveType element_type,
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
absl::Span<const int64> broadcast_dimensions);
|
||||||
|
|
||||||
|
// Same as above, but fill the tensor with ones.
|
||||||
|
HloInstruction* BroadcastOnes(HloComputation* computation,
|
||||||
|
PrimitiveType element_type,
|
||||||
|
absl::Span<const int64> broadcast_dimensions);
|
||||||
|
|
||||||
// Creates a HLO computation that takes arguments of type `domain` and produces
|
// Creates a HLO computation that takes arguments of type `domain` and produces
|
||||||
// a value of type `range`.
|
// a value of type `range`.
|
||||||
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
|
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
|
||||||
|
Loading…
Reference in New Issue
Block a user