Add support for PopulationCount.
PiperOrigin-RevId: 239066669
This commit is contained in:
parent
c8d3f6f590
commit
73cd0f7199
@ -3167,6 +3167,10 @@ XlaOp Not(const XlaOp& operand) {
|
||||
return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
|
||||
}
|
||||
|
||||
XlaOp PopulationCount(const XlaOp& operand) {
|
||||
return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand);
|
||||
}
|
||||
|
||||
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
|
||||
|
@ -852,6 +852,7 @@ class XlaBuilder {
|
||||
friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp Not(const XlaOp& operand);
|
||||
friend XlaOp PopulationCount(const XlaOp& operand);
|
||||
friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
friend XlaOp ShiftRightArithmetic(
|
||||
@ -1519,6 +1520,8 @@ XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
|
||||
|
||||
XlaOp Not(const XlaOp& operand);
|
||||
|
||||
XlaOp PopulationCount(const XlaOp& operand);
|
||||
|
||||
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
|
||||
absl::Span<const int64> broadcast_dimensions = {});
|
||||
XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
|
||||
|
@ -1237,6 +1237,9 @@ if and only if the corresponding input element is finite.
|
||||
|
||||
<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
|
||||
|
||||
<b>`PopulationCount(operand)`</b> Computes the number of bits set in each
|
||||
element of `operand`.
|
||||
|
||||
<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
|
||||
|
||||
<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
|
||||
|
@ -199,6 +199,9 @@ class DfsHloVisitorBase {
|
||||
virtual Status HandleXor(HloInstructionPtr hlo) {
|
||||
return HandleElementwiseBinary(hlo);
|
||||
}
|
||||
virtual Status HandlePopulationCount(HloInstructionPtr hlo) {
|
||||
return HandleElementwiseUnary(hlo);
|
||||
}
|
||||
virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
|
||||
return HandleElementwiseBinary(hlo);
|
||||
}
|
||||
|
@ -326,6 +326,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
|
||||
}
|
||||
return Unimplemented("unary op Not is not defined for type '%d'", type);
|
||||
}
|
||||
case HloOpcode::kPopulationCount: {
|
||||
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop,
|
||||
{operand_value},
|
||||
{operand_value->getType()}, b_);
|
||||
}
|
||||
default:
|
||||
return Unimplemented("unary integer op '%s'",
|
||||
HloOpcodeString(op->opcode()));
|
||||
@ -2219,6 +2224,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kPopulationCount:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kSign:
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
|
||||
|
||||
#include <bitset>
|
||||
#include <cmath>
|
||||
#include <type_traits>
|
||||
|
||||
@ -2482,6 +2483,37 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return HandleClz<ElementwiseT>(clz);
|
||||
}
|
||||
|
||||
// Enable Popcnt only for int32, uint32, int64 and uint64.
|
||||
template <typename NativeT,
|
||||
typename std::enable_if<
|
||||
!(std::is_same<NativeT, uint32>::value ||
|
||||
std::is_same<NativeT, int32>::value ||
|
||||
std::is_same<NativeT, uint64>::value ||
|
||||
std::is_same<NativeT, int64>::value)>::type* = nullptr>
|
||||
Status HandlePopulationCount(HloInstruction* popcnt) {
|
||||
return UnsupportedTypeError(popcnt);
|
||||
}
|
||||
|
||||
template <typename NativeT,
|
||||
typename std::enable_if<
|
||||
std::is_same<NativeT, uint32>::value ||
|
||||
std::is_same<NativeT, int32>::value ||
|
||||
std::is_same<NativeT, uint64>::value ||
|
||||
std::is_same<NativeT, int64>::value>::type* = nullptr>
|
||||
Status HandlePopulationCount(HloInstruction* popcnt) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[popcnt],
|
||||
ElementWiseUnaryOp(popcnt, [](ElementwiseT elem_operand) {
|
||||
return std::bitset<CHAR_BIT * sizeof elem_operand>(elem_operand)
|
||||
.count();
|
||||
}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HandlePopulationCount(HloInstruction* popcnt) override {
|
||||
return HandlePopulationCount<ElementwiseT>(popcnt);
|
||||
}
|
||||
|
||||
template <typename NativeT, typename std::enable_if<std::is_floating_point<
|
||||
NativeT>::value>::type* = nullptr>
|
||||
Status HandleSin(HloInstruction* sin) {
|
||||
|
@ -947,6 +947,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kPopulationCount:
|
||||
case HloOpcode::kOr:
|
||||
case HloOpcode::kXor:
|
||||
case HloOpcode::kPower:
|
||||
|
@ -703,6 +703,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kPopulationCount:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kSign:
|
||||
@ -1408,6 +1409,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kPopulationCount:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kSign:
|
||||
@ -1754,6 +1756,7 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kPopulationCount:
|
||||
case HloOpcode::kPower:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kRemainder:
|
||||
@ -2135,6 +2138,7 @@ bool HloInstruction::IsElementwiseImpl(
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kPopulationCount:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kReducePrecision:
|
||||
case HloOpcode::kRsqrt:
|
||||
@ -2600,6 +2604,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
|
||||
return visitor->HandleIsFinite(this);
|
||||
case HloOpcode::kNot:
|
||||
return visitor->HandleNot(this);
|
||||
case HloOpcode::kPopulationCount:
|
||||
return visitor->HandlePopulationCount(this);
|
||||
case HloOpcode::kBitcast:
|
||||
return visitor->HandleBitcast(this);
|
||||
case HloOpcode::kBroadcast:
|
||||
|
@ -108,6 +108,7 @@ namespace xla {
|
||||
V(kOutfeed, "outfeed", 2) \
|
||||
V(kPad, "pad", 2) \
|
||||
V(kParameter, "parameter", 0) \
|
||||
V(kPopulationCount, "popcnt", 1) \
|
||||
V(kPower, "power", 2) \
|
||||
V(kReal, "real", 1) \
|
||||
V(kRecv, "recv", 1) \
|
||||
|
@ -744,6 +744,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kPopulationCount:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kSign:
|
||||
|
@ -88,6 +88,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
|
||||
case HloOpcode::kXor:
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kPad:
|
||||
case HloOpcode::kPopulationCount:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kReducePrecision:
|
||||
case HloOpcode::kReplicaId:
|
||||
|
@ -2016,6 +2016,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
|
||||
case HloOpcode::kOr:
|
||||
case HloOpcode::kXor:
|
||||
case HloOpcode::kPad:
|
||||
case HloOpcode::kPopulationCount:
|
||||
case HloOpcode::kPower:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kReducePrecision:
|
||||
|
@ -308,6 +308,14 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||
HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
|
||||
}
|
||||
return shape;
|
||||
case HloOpcode::kPopulationCount:
|
||||
if (!ShapeUtil::ElementIsIntegral(shape)) {
|
||||
return InvalidArgument(
|
||||
"Expected an integral element type in argument to PopulationCount "
|
||||
"operation; got %s.",
|
||||
PrimitiveType_Name(shape.element_type()));
|
||||
}
|
||||
return shape;
|
||||
case HloOpcode::kSign:
|
||||
if (!ShapeUtil::ElementIsSigned(shape) &&
|
||||
!ShapeUtil::ElementIsComplex(shape)) {
|
||||
|
@ -1065,6 +1065,29 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
|
||||
ComputeAndCompareR1<uint32>(&builder, {}, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR1) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto a = ConstantR1<int32>(&builder, {0, 1, -15, 341});
|
||||
PopulationCount(a);
|
||||
ComputeAndCompareR1<int32>(&builder, {0, 1, 29, 5}, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR2) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto a = ConstantR2<int32>(&builder, {{0, 1}, {-15, 341}});
|
||||
PopulationCount(a);
|
||||
Array2D<int32> expected_array({{0, 1}, {29, 5}});
|
||||
ComputeAndCompareR2<int32>(&builder, expected_array, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, PopcntS64) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto a = ConstantR2<int64>(&builder, {{0, -1}, {INT64_MAX, INT64_MAX - 1}});
|
||||
PopulationCount(a);
|
||||
Array2D<int64> expected_array({{0, 64}, {63, 62}});
|
||||
ComputeAndCompareR2<int64>(&builder, expected_array, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
|
||||
XlaBuilder builder(TestName());
|
||||
auto a = ConstantR1<int32>(
|
||||
|
Loading…
Reference in New Issue
Block a user