Implement Mul(Convert(Pred), operand) => select(pred, operand, 0) optimization.

PiperOrigin-RevId: 316739811
Change-Id: Ica5e50c6639a9792ae1dd47eefd713021fb97533
This commit is contained in:
Yunxing Dai 2020-06-16 12:54:54 -07:00 committed by TensorFlower Gardener
parent e0266dbf39
commit ed1d7d09ae
3 changed files with 33 additions and 0 deletions

View File

@ -2455,6 +2455,25 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
return Status::OK();
}
{
HloInstruction *convert_operand, *operand;
// Mul(Convert(Pred), operand) => select(pred, operand, 0)
if (Match(multiply,
m::MultiplyAnyOrder(
m::Op(&operand),
m::Convert(
m::Op(&convert_operand)
.WithShape(m::Shape().WithElementType(PRED)))))) {
HloInstruction* zero_like_multiply =
BroadcastZeros(computation_, multiply->shape().element_type(),
multiply->shape().dimensions());
return ReplaceWithNewInstruction(
multiply, HloInstruction::CreateTernary(
multiply->shape(), HloOpcode::kSelect, convert_operand,
operand, zero_like_multiply));
}
}
VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
HloInstruction *a, *c1, *c2;
if (Match(multiply,

View File

@ -539,6 +539,15 @@ HloInstruction* BroadcastZeros(HloComputation* computation,
/*result_shape_bounds=*/broadcast_dimensions);
}
HloInstruction* BroadcastOnes(HloComputation* computation,
PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions) {
HloInstruction* one = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::One(element_type)));
return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
// Recursively creates a dummy op given a shape. Leaf nodes are broadcasted zero
// while internal nodes are tuples.
HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {

View File

@ -276,6 +276,11 @@ HloInstruction* BroadcastZeros(HloComputation* computation,
PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions);
// Same as above, but fill the tensor with ones.
HloInstruction* BroadcastOnes(HloComputation* computation,
PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions);
// Creates a HLO computation that takes arguments of type `domain` and produces
// a value of type `range`.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(