From ed1d7d09aec54c8c277da957ed18d17ed6885711 Mon Sep 17 00:00:00 2001
From: Yunxing Dai
Date: Tue, 16 Jun 2020 12:54:54 -0700
Subject: [PATCH] Implement Mul(Convert(Pred), operand) => select(pred, operand, 0) optimization.

PiperOrigin-RevId: 316739811
Change-Id: Ica5e50c6639a9792ae1dd47eefd713021fb97533
---
 .../xla/service/algebraic_simplifier.cc       | 19 +++++++++++++++++++
 .../xla/service/hlo_creation_utils.cc         |  9 +++++++++
 .../compiler/xla/service/hlo_creation_utils.h |  5 +++++
 3 files changed, 33 insertions(+)

diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 98e3229b062..ce2a801fccd 100755
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -2455,6 +2455,25 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
     return Status::OK();
   }
 
+  {
+    HloInstruction *convert_operand, *operand;
+    // Mul(Convert(pred), operand) => Select(pred, operand, 0); pred is PRED.
+    if (Match(multiply,
+              m::MultiplyAnyOrder(
+                  m::Op(&operand),
+                  m::Convert(
+                      m::Op(&convert_operand)
+                          .WithShape(m::Shape().WithElementType(PRED)))))) {
+      HloInstruction* zero_like_multiply =
+          BroadcastZeros(computation_, multiply->shape().element_type(),
+                         multiply->shape().dimensions());
+      return ReplaceWithNewInstruction(
+          multiply, HloInstruction::CreateTernary(
+                        multiply->shape(), HloOpcode::kSelect, convert_operand,
+                        operand, zero_like_multiply));
+    }
+  }
+
   VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
   HloInstruction *a, *c1, *c2;
   if (Match(multiply,
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index dd174772c62..0f5267e9fbc 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -539,6 +539,15 @@ HloInstruction* BroadcastZeros(HloComputation* computation,
/*result_shape_bounds=*/broadcast_dimensions); } +HloInstruction* BroadcastOnes(HloComputation* computation, + PrimitiveType element_type, + absl::Span broadcast_dimensions) { + HloInstruction* one = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::One(element_type))); + return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{}, + /*result_shape_bounds=*/broadcast_dimensions); +} + // Recursively creates a dummy op given a shape. Leaf nodes are broadcasted zero // while internal nodes are tuples. HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h index 3f2e3aa25a1..2ba753d3cdb 100644 --- a/tensorflow/compiler/xla/service/hlo_creation_utils.h +++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h @@ -276,6 +276,11 @@ HloInstruction* BroadcastZeros(HloComputation* computation, PrimitiveType element_type, absl::Span broadcast_dimensions); +// Same as above, but fill the tensor with ones. +HloInstruction* BroadcastOnes(HloComputation* computation, + PrimitiveType element_type, + absl::Span broadcast_dimensions); + // Creates a HLO computation that takes arguments of type `domain` and produces // a value of type `range`. StatusOr> CreateComputationWithSignature(