Add support for PopulationCount.
PiperOrigin-RevId: 239066669
This commit is contained in:
parent
c8d3f6f590
commit
73cd0f7199
@ -3167,6 +3167,10 @@ XlaOp Not(const XlaOp& operand) {
|
|||||||
return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
|
return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaOp PopulationCount(const XlaOp& operand) {
|
||||||
|
return operand.builder()->UnaryOp(HloOpcode::kPopulationCount, operand);
|
||||||
|
}
|
||||||
|
|
||||||
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions) {
|
absl::Span<const int64> broadcast_dimensions) {
|
||||||
return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
|
return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
|
||||||
|
@ -852,6 +852,7 @@ class XlaBuilder {
|
|||||||
friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
|
friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
absl::Span<const int64> broadcast_dimensions);
|
||||||
friend XlaOp Not(const XlaOp& operand);
|
friend XlaOp Not(const XlaOp& operand);
|
||||||
|
friend XlaOp PopulationCount(const XlaOp& operand);
|
||||||
friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
|
friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions);
|
absl::Span<const int64> broadcast_dimensions);
|
||||||
friend XlaOp ShiftRightArithmetic(
|
friend XlaOp ShiftRightArithmetic(
|
||||||
@ -1519,6 +1520,8 @@ XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
|
|||||||
|
|
||||||
XlaOp Not(const XlaOp& operand);
|
XlaOp Not(const XlaOp& operand);
|
||||||
|
|
||||||
|
XlaOp PopulationCount(const XlaOp& operand);
|
||||||
|
|
||||||
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
absl::Span<const int64> broadcast_dimensions = {});
|
absl::Span<const int64> broadcast_dimensions = {});
|
||||||
XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
|
XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
|
||||||
|
@ -1237,6 +1237,9 @@ if and only if the corresponding input element is finite.
|
|||||||
|
|
||||||
<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
|
<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
|
||||||
|
|
||||||
|
<b>`PopulationCount(operand)`</b> Computes the number of bits set in each
|
||||||
|
element of `operand`.
|
||||||
|
|
||||||
<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
|
<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
|
||||||
|
|
||||||
<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
|
<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
|
||||||
|
@ -199,6 +199,9 @@ class DfsHloVisitorBase {
|
|||||||
virtual Status HandleXor(HloInstructionPtr hlo) {
|
virtual Status HandleXor(HloInstructionPtr hlo) {
|
||||||
return HandleElementwiseBinary(hlo);
|
return HandleElementwiseBinary(hlo);
|
||||||
}
|
}
|
||||||
|
virtual Status HandlePopulationCount(HloInstructionPtr hlo) {
|
||||||
|
return HandleElementwiseUnary(hlo);
|
||||||
|
}
|
||||||
virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
|
virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
|
||||||
return HandleElementwiseBinary(hlo);
|
return HandleElementwiseBinary(hlo);
|
||||||
}
|
}
|
||||||
|
@ -326,6 +326,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
|
|||||||
}
|
}
|
||||||
return Unimplemented("unary op Not is not defined for type '%d'", type);
|
return Unimplemented("unary op Not is not defined for type '%d'", type);
|
||||||
}
|
}
|
||||||
|
case HloOpcode::kPopulationCount: {
|
||||||
|
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop,
|
||||||
|
{operand_value},
|
||||||
|
{operand_value->getType()}, b_);
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
return Unimplemented("unary integer op '%s'",
|
return Unimplemented("unary integer op '%s'",
|
||||||
HloOpcodeString(op->opcode()));
|
HloOpcodeString(op->opcode()));
|
||||||
@ -2219,6 +2224,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
|
|||||||
case HloOpcode::kLog1p:
|
case HloOpcode::kLog1p:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
|
case HloOpcode::kPopulationCount:
|
||||||
case HloOpcode::kReal:
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kRsqrt:
|
case HloOpcode::kRsqrt:
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
|
||||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
|
||||||
|
|
||||||
|
#include <bitset>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
|
|
||||||
@ -2482,6 +2483,37 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
|||||||
return HandleClz<ElementwiseT>(clz);
|
return HandleClz<ElementwiseT>(clz);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enable Popcnt only for int32, uint32, int64 and uint64.
|
||||||
|
template <typename NativeT,
|
||||||
|
typename std::enable_if<
|
||||||
|
!(std::is_same<NativeT, uint32>::value ||
|
||||||
|
std::is_same<NativeT, int32>::value ||
|
||||||
|
std::is_same<NativeT, uint64>::value ||
|
||||||
|
std::is_same<NativeT, int64>::value)>::type* = nullptr>
|
||||||
|
Status HandlePopulationCount(HloInstruction* popcnt) {
|
||||||
|
return UnsupportedTypeError(popcnt);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename NativeT,
|
||||||
|
typename std::enable_if<
|
||||||
|
std::is_same<NativeT, uint32>::value ||
|
||||||
|
std::is_same<NativeT, int32>::value ||
|
||||||
|
std::is_same<NativeT, uint64>::value ||
|
||||||
|
std::is_same<NativeT, int64>::value>::type* = nullptr>
|
||||||
|
Status HandlePopulationCount(HloInstruction* popcnt) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
parent_->evaluated_[popcnt],
|
||||||
|
ElementWiseUnaryOp(popcnt, [](ElementwiseT elem_operand) {
|
||||||
|
return std::bitset<CHAR_BIT * sizeof elem_operand>(elem_operand)
|
||||||
|
.count();
|
||||||
|
}));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status HandlePopulationCount(HloInstruction* popcnt) override {
|
||||||
|
return HandlePopulationCount<ElementwiseT>(popcnt);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename NativeT, typename std::enable_if<std::is_floating_point<
|
template <typename NativeT, typename std::enable_if<std::is_floating_point<
|
||||||
NativeT>::value>::type* = nullptr>
|
NativeT>::value>::type* = nullptr>
|
||||||
Status HandleSin(HloInstruction* sin) {
|
Status HandleSin(HloInstruction* sin) {
|
||||||
|
@ -947,6 +947,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
|||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
|
case HloOpcode::kPopulationCount:
|
||||||
case HloOpcode::kOr:
|
case HloOpcode::kOr:
|
||||||
case HloOpcode::kXor:
|
case HloOpcode::kXor:
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
|
@ -703,6 +703,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
|
|||||||
case HloOpcode::kLog1p:
|
case HloOpcode::kLog1p:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
|
case HloOpcode::kPopulationCount:
|
||||||
case HloOpcode::kReal:
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kRsqrt:
|
case HloOpcode::kRsqrt:
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
@ -1408,6 +1409,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
|||||||
case HloOpcode::kLog1p:
|
case HloOpcode::kLog1p:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
|
case HloOpcode::kPopulationCount:
|
||||||
case HloOpcode::kReal:
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kRsqrt:
|
case HloOpcode::kRsqrt:
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
@ -1754,6 +1756,7 @@ bool HloInstruction::IdenticalSlowPath(
|
|||||||
case HloOpcode::kMinimum:
|
case HloOpcode::kMinimum:
|
||||||
case HloOpcode::kMultiply:
|
case HloOpcode::kMultiply:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
|
case HloOpcode::kPopulationCount:
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
case HloOpcode::kReal:
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kRemainder:
|
case HloOpcode::kRemainder:
|
||||||
@ -2135,6 +2138,7 @@ bool HloInstruction::IsElementwiseImpl(
|
|||||||
case HloOpcode::kLog1p:
|
case HloOpcode::kLog1p:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
|
case HloOpcode::kPopulationCount:
|
||||||
case HloOpcode::kReal:
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kReducePrecision:
|
case HloOpcode::kReducePrecision:
|
||||||
case HloOpcode::kRsqrt:
|
case HloOpcode::kRsqrt:
|
||||||
@ -2600,6 +2604,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
|
|||||||
return visitor->HandleIsFinite(this);
|
return visitor->HandleIsFinite(this);
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
return visitor->HandleNot(this);
|
return visitor->HandleNot(this);
|
||||||
|
case HloOpcode::kPopulationCount:
|
||||||
|
return visitor->HandlePopulationCount(this);
|
||||||
case HloOpcode::kBitcast:
|
case HloOpcode::kBitcast:
|
||||||
return visitor->HandleBitcast(this);
|
return visitor->HandleBitcast(this);
|
||||||
case HloOpcode::kBroadcast:
|
case HloOpcode::kBroadcast:
|
||||||
|
@ -108,6 +108,7 @@ namespace xla {
|
|||||||
V(kOutfeed, "outfeed", 2) \
|
V(kOutfeed, "outfeed", 2) \
|
||||||
V(kPad, "pad", 2) \
|
V(kPad, "pad", 2) \
|
||||||
V(kParameter, "parameter", 0) \
|
V(kParameter, "parameter", 0) \
|
||||||
|
V(kPopulationCount, "popcnt", 1) \
|
||||||
V(kPower, "power", 2) \
|
V(kPower, "power", 2) \
|
||||||
V(kReal, "real", 1) \
|
V(kReal, "real", 1) \
|
||||||
V(kRecv, "recv", 1) \
|
V(kRecv, "recv", 1) \
|
||||||
|
@ -744,6 +744,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
|
|||||||
case HloOpcode::kLog1p:
|
case HloOpcode::kLog1p:
|
||||||
case HloOpcode::kNot:
|
case HloOpcode::kNot:
|
||||||
case HloOpcode::kNegate:
|
case HloOpcode::kNegate:
|
||||||
|
case HloOpcode::kPopulationCount:
|
||||||
case HloOpcode::kReal:
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kRsqrt:
|
case HloOpcode::kRsqrt:
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
|
@ -88,6 +88,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
|
|||||||
case HloOpcode::kXor:
|
case HloOpcode::kXor:
|
||||||
case HloOpcode::kOutfeed:
|
case HloOpcode::kOutfeed:
|
||||||
case HloOpcode::kPad:
|
case HloOpcode::kPad:
|
||||||
|
case HloOpcode::kPopulationCount:
|
||||||
case HloOpcode::kReal:
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kReducePrecision:
|
case HloOpcode::kReducePrecision:
|
||||||
case HloOpcode::kReplicaId:
|
case HloOpcode::kReplicaId:
|
||||||
|
@ -2016,6 +2016,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
|
|||||||
case HloOpcode::kOr:
|
case HloOpcode::kOr:
|
||||||
case HloOpcode::kXor:
|
case HloOpcode::kXor:
|
||||||
case HloOpcode::kPad:
|
case HloOpcode::kPad:
|
||||||
|
case HloOpcode::kPopulationCount:
|
||||||
case HloOpcode::kPower:
|
case HloOpcode::kPower:
|
||||||
case HloOpcode::kReal:
|
case HloOpcode::kReal:
|
||||||
case HloOpcode::kReducePrecision:
|
case HloOpcode::kReducePrecision:
|
||||||
|
@ -308,6 +308,14 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
|||||||
HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
|
HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
|
||||||
}
|
}
|
||||||
return shape;
|
return shape;
|
||||||
|
case HloOpcode::kPopulationCount:
|
||||||
|
if (!ShapeUtil::ElementIsIntegral(shape)) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"Expected an integral element type in argument to PopulationCount "
|
||||||
|
"operation; got %s.",
|
||||||
|
PrimitiveType_Name(shape.element_type()));
|
||||||
|
}
|
||||||
|
return shape;
|
||||||
case HloOpcode::kSign:
|
case HloOpcode::kSign:
|
||||||
if (!ShapeUtil::ElementIsSigned(shape) &&
|
if (!ShapeUtil::ElementIsSigned(shape) &&
|
||||||
!ShapeUtil::ElementIsComplex(shape)) {
|
!ShapeUtil::ElementIsComplex(shape)) {
|
||||||
|
@ -1065,6 +1065,29 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
|
|||||||
ComputeAndCompareR1<uint32>(&builder, {}, {});
|
ComputeAndCompareR1<uint32>(&builder, {}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR1) {
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
auto a = ConstantR1<int32>(&builder, {0, 1, -15, 341});
|
||||||
|
PopulationCount(a);
|
||||||
|
ComputeAndCompareR1<int32>(&builder, {0, 1, 29, 5}, {});
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR2) {
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
auto a = ConstantR2<int32>(&builder, {{0, 1}, {-15, 341}});
|
||||||
|
PopulationCount(a);
|
||||||
|
Array2D<int32> expected_array({{0, 1}, {29, 5}});
|
||||||
|
ComputeAndCompareR2<int32>(&builder, expected_array, {});
|
||||||
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(ArrayElementwiseOpTest, PopcntS64) {
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
auto a = ConstantR2<int64>(&builder, {{0, -1}, {INT64_MAX, INT64_MAX - 1}});
|
||||||
|
PopulationCount(a);
|
||||||
|
Array2D<int64> expected_array({{0, 64}, {63, 62}});
|
||||||
|
ComputeAndCompareR2<int64>(&builder, expected_array, {});
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
|
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
|
||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
auto a = ConstantR1<int32>(
|
auto a = ConstantR1<int32>(
|
||||||
|
Loading…
Reference in New Issue
Block a user