Add a PopulationCount ("popcnt") unary elementwise HLO operation to XLA.

Summary of what the hunks below do:
  * client:     new XlaOp PopulationCount(operand) builder entry point.
  * opcode:     new HloOpcode::kPopulationCount, printed as "popcnt", arity 1.
  * visitor:    DfsHloVisitorBase::HandlePopulationCount, defaulting to
                HandleElementwiseUnary.
  * shape:      shape inference rejects non-integral element types.
  * evaluator:  HloEvaluatorTypedVisitor counts set bits via std::bitset, and
                returns UnsupportedTypeError for anything other than
                int32/uint32/int64/uint64.
  * codegen:    ElementalIrEmitter lowers the op to the llvm ctpop intrinsic.
  * plumbing:   the new opcode is added to the unary-op case lists in
                hlo_instruction.cc, hlo_parser.cc, hlo_graph_dumper.cc,
                instruction_fusion.cc and layout_assignment.cc.
  * docs/tests: operation_semantics.md entry; PopcntR1/PopcntR2/PopcntS64.

NOTE(review): this patch text was recovered from a copy in which line breaks,
indentation and every angle-bracketed token (template arguments, #include
targets, <b> tags) had been lost. Those were reconstructed; added ("+") lines
are believed accurate, but context lines and hunk-header function names may
not match upstream byte-for-byte. Verify against the upstream commit before
applying.
NOTE(review): the evaluator hunk uses CHAR_BIT but only adds <bitset>; confirm
<climits> is reachable from hlo_evaluator_typed_visitor.h.

diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 2f574366694..570f24ba7ef 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -3167,6 +3167,10 @@ XlaOp Not(const XlaOp& operand) {
   return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
 }
 
+XlaOp PopulationCount(const XlaOp& operand) {
+  return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand);
+}
+
 XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                 absl::Span<const int64> broadcast_dimensions) {
   return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 80f93a8b6de..6233c7ab166 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -852,6 +852,7 @@ class XlaBuilder {
   friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
                    absl::Span<const int64> broadcast_dimensions);
   friend XlaOp Not(const XlaOp& operand);
+  friend XlaOp PopulationCount(const XlaOp& operand);
   friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                          absl::Span<const int64> broadcast_dimensions);
   friend XlaOp ShiftRightArithmetic(
@@ -1519,6 +1520,8 @@ XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
 
 XlaOp Not(const XlaOp& operand);
 
+XlaOp PopulationCount(const XlaOp& operand);
+
 XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
                 absl::Span<const int64> broadcast_dimensions = {});
 XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md
index 7d718c53010..7b6edc7ff4f 100644
--- a/tensorflow/compiler/xla/g3doc/operation_semantics.md
+++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md
@@ -1237,6 +1237,9 @@ if and only if the corresponding input element is finite.
 
 <b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
 
+<b>`PopulationCount(operand)`</b> Computes the number of bits set in each
+element of `operand`.
+
 <b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
 
 <b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 246f2af09b5..9e05155bce5 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -199,6 +199,9 @@ class DfsHloVisitorBase {
   virtual Status HandleXor(HloInstructionPtr hlo) {
     return HandleElementwiseBinary(hlo);
   }
+  virtual Status HandlePopulationCount(HloInstructionPtr hlo) {
+    return HandleElementwiseUnary(hlo);
+  }
   virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
     return HandleElementwiseBinary(hlo);
   }
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 53513fa5226..d10e37a2003 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -326,6 +326,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
       }
       return Unimplemented("unary op Not is not defined for type '%d'", type);
     }
+    case HloOpcode::kPopulationCount: {
+      return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop,
+                                          {operand_value},
+                                          {operand_value->getType()}, b_);
+    }
     default:
       return Unimplemented("unary integer op '%s'",
                            HloOpcodeString(op->opcode()));
@@ -2219,6 +2224,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
     case HloOpcode::kLog1p:
     case HloOpcode::kNegate:
     case HloOpcode::kNot:
+    case HloOpcode::kPopulationCount:
     case HloOpcode::kReal:
     case HloOpcode::kRsqrt:
     case HloOpcode::kSign:
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 2d8a578985e..7cbf880690a 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -16,6 +16,7 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
 
+#include <bitset>
 #include <cmath>
 #include <type_traits>
 
@@ -2482,6 +2483,37 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
     return HandleClz<ElementwiseT>(clz);
   }
 
+  // Enable Popcnt only for int32, uint32, int64 and uint64.
+  template <typename NativeT,
+            typename std::enable_if<
+                !(std::is_same<NativeT, uint32>::value ||
+                  std::is_same<NativeT, int32>::value ||
+                  std::is_same<NativeT, int64>::value ||
+                  std::is_same<NativeT, uint64>::value)>::type* = nullptr>
+  Status HandlePopulationCount(HloInstruction* popcnt) {
+    return UnsupportedTypeError(popcnt);
+  }
+
+  template <typename NativeT,
+            typename std::enable_if<
+                std::is_same<NativeT, uint32>::value ||
+                std::is_same<NativeT, int32>::value ||
+                std::is_same<NativeT, int64>::value ||
+                std::is_same<NativeT, uint64>::value>::type* = nullptr>
+  Status HandlePopulationCount(HloInstruction* popcnt) {
+    TF_ASSIGN_OR_RETURN(
+        parent_->evaluated_[popcnt],
+        ElementWiseUnaryOp(popcnt, [](ElementwiseT elem_operand) {
+          return std::bitset<CHAR_BIT * sizeof elem_operand>(elem_operand)
+              .count();
+        }));
+    return Status::OK();
+  }
+
+  Status HandlePopulationCount(HloInstruction* popcnt) override {
+    return HandlePopulationCount<ElementwiseT>(popcnt);
+  }
+
   template <typename NativeT, typename std::enable_if<std::is_floating_point<
                                   NativeT>::value>::type* = nullptr>
   Status HandleSin(HloInstruction* sin) {
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 116b32f5f4c..70ba0b1faa3 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -947,6 +947,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
     case HloOpcode::kMultiply:
     case HloOpcode::kNegate:
     case HloOpcode::kNot:
+    case HloOpcode::kPopulationCount:
     case HloOpcode::kOr:
     case HloOpcode::kXor:
     case HloOpcode::kPower:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index fe8a178f80f..d8d473f271f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -703,6 +703,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
     case HloOpcode::kLog1p:
     case HloOpcode::kNot:
     case HloOpcode::kNegate:
+    case HloOpcode::kPopulationCount:
     case HloOpcode::kReal:
     case HloOpcode::kRsqrt:
     case HloOpcode::kSign:
@@ -1408,6 +1409,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
     case HloOpcode::kLog1p:
     case HloOpcode::kNot:
     case HloOpcode::kNegate:
+    case HloOpcode::kPopulationCount:
     case HloOpcode::kReal:
     case HloOpcode::kRsqrt:
     case HloOpcode::kSign:
@@ -1754,6 +1756,7 @@ bool HloInstruction::IdenticalSlowPath(
     case HloOpcode::kMinimum:
     case HloOpcode::kMultiply:
     case HloOpcode::kNegate:
+    case HloOpcode::kPopulationCount:
     case HloOpcode::kPower:
     case HloOpcode::kReal:
     case HloOpcode::kRemainder:
@@ -2135,6 +2138,7 @@ bool HloInstruction::IsElementwiseImpl(
     case HloOpcode::kLog1p:
     case HloOpcode::kNot:
     case HloOpcode::kNegate:
+    case HloOpcode::kPopulationCount:
     case HloOpcode::kReal:
     case HloOpcode::kReducePrecision:
     case HloOpcode::kRsqrt:
@@ -2600,6 +2604,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
       return visitor->HandleIsFinite(this);
     case HloOpcode::kNot:
       return visitor->HandleNot(this);
+    case HloOpcode::kPopulationCount:
+      return visitor->HandlePopulationCount(this);
     case HloOpcode::kBitcast:
       return visitor->HandleBitcast(this);
     case HloOpcode::kBroadcast:
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index c5ccd49552a..6d3a49898c2 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -108,6 +108,7 @@ namespace xla {
   V(kOutfeed, "outfeed", 2)                                            \
   V(kPad, "pad", 2)                                                    \
   V(kParameter, "parameter", 0)                                        \
+  V(kPopulationCount, "popcnt", 1)                                     \
   V(kPower, "power", 2)                                                \
   V(kReal, "real", 1)                                                  \
   V(kRecv, "recv", 1)                                                  \
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 8e76a1f262e..ce412912a66 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -744,6 +744,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
     case HloOpcode::kLog1p:
     case HloOpcode::kNot:
     case HloOpcode::kNegate:
+    case HloOpcode::kPopulationCount:
     case HloOpcode::kReal:
     case HloOpcode::kRsqrt:
     case HloOpcode::kSign:
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 4868cf961aa..a30933889f2 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -88,6 +88,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
     case HloOpcode::kXor:
     case HloOpcode::kOutfeed:
     case HloOpcode::kPad:
+    case HloOpcode::kPopulationCount:
     case HloOpcode::kReal:
     case HloOpcode::kReducePrecision:
     case HloOpcode::kReplicaId:
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 039954a1837..20577514935 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -2016,6 +2016,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
     case HloOpcode::kOr:
     case HloOpcode::kXor:
     case HloOpcode::kPad:
+    case HloOpcode::kPopulationCount:
     case HloOpcode::kPower:
     case HloOpcode::kReal:
     case HloOpcode::kReducePrecision:
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index e1536684c06..533d3a940af 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -308,6 +308,14 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
             HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
       }
       return shape;
+    case HloOpcode::kPopulationCount:
+      if (!ShapeUtil::ElementIsIntegral(shape)) {
+        return InvalidArgument(
+            "Expected an integral element type in argument to PopulationCount "
+            "operation; got %s.",
+            PrimitiveType_Name(shape.element_type()));
+      }
+      return shape;
     case HloOpcode::kSign:
       if (!ShapeUtil::ElementIsSigned(shape) &&
           !ShapeUtil::ElementIsComplex(shape)) {
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 21458b40b10..a5e27cd67a7 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -1065,6 +1065,29 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
   ComputeAndCompareR1<uint32>(&builder, {}, {});
 }
 
+XLA_TEST_F(ArrayElementwiseOpTest, PopcntR1) {
+  XlaBuilder builder(TestName());
+  auto a = ConstantR1<int32>(&builder, {0, 1, -15, 341});
+  PopulationCount(a);
+  ComputeAndCompareR1<int32>(&builder, {0, 1, 29, 5}, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, PopcntR2) {
+  XlaBuilder builder(TestName());
+  auto a = ConstantR2<int32>(&builder, {{0, 1}, {-15, 341}});
+  PopulationCount(a);
+  Array2D<int32> expected_array({{0, 1}, {29, 5}});
+  ComputeAndCompareR2<int32>(&builder, expected_array, {});
+}
+
+XLA_TEST_F(ArrayElementwiseOpTest, PopcntS64) {
+  XlaBuilder builder(TestName());
+  auto a = ConstantR2<int64>(&builder, {{0, -1}, {INT64_MAX, INT64_MAX - 1}});
+  PopulationCount(a);
+  Array2D<int64> expected_array({{0, 64}, {63, 62}});
+  ComputeAndCompareR2<int64>(&builder, expected_array, {});
+}
+
 XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
   XlaBuilder builder(TestName());
   auto a = ConstantR1<int32>(