Add support for PopulationCount.

PiperOrigin-RevId: 239066669
This commit is contained in:
A. Unique TensorFlower 2019-03-18 14:52:18 -07:00 committed by TensorFlower Gardener
parent c8d3f6f590
commit 73cd0f7199
14 changed files with 93 additions and 0 deletions

View File

@ -3167,6 +3167,10 @@ XlaOp Not(const XlaOp& operand) {
return operand.builder()->UnaryOp(HloOpcode::kNot, operand); return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
} }
// Adds an element-wise kPopulationCount unary instruction for `operand`,
// using the builder that `operand` belongs to, and returns the resulting op.
XlaOp PopulationCount(const XlaOp& operand) {
  auto* owning_builder = operand.builder();
  return owning_builder->UnaryOp(HloOpcode::kPopulationCount, operand);
}
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) { absl::Span<const int64> broadcast_dimensions) {
return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,

View File

@ -852,6 +852,7 @@ class XlaBuilder {
friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs, friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions); absl::Span<const int64> broadcast_dimensions);
friend XlaOp Not(const XlaOp& operand); friend XlaOp Not(const XlaOp& operand);
friend XlaOp PopulationCount(const XlaOp& operand);
friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions); absl::Span<const int64> broadcast_dimensions);
friend XlaOp ShiftRightArithmetic( friend XlaOp ShiftRightArithmetic(
@ -1519,6 +1520,8 @@ XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
XlaOp Not(const XlaOp& operand); XlaOp Not(const XlaOp& operand);
XlaOp PopulationCount(const XlaOp& operand);
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs, XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions = {}); absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs, XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,

View File

@ -1237,6 +1237,9 @@ if and only if the corresponding input element is finite.
<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`. <b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
<b>`PopulationCount(operand)`</b> Computes the number of bits set in each
element of `operand`.
<b>`Neg(operand)`</b> Element-wise negation `x -> -x`. <b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where <b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where

View File

@ -199,6 +199,9 @@ class DfsHloVisitorBase {
virtual Status HandleXor(HloInstructionPtr hlo) { virtual Status HandleXor(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo); return HandleElementwiseBinary(hlo);
} }
// Default visitor hook for kPopulationCount instructions: unless a subclass
// overrides it, the instruction is routed to the generic element-wise unary
// handler.
virtual Status HandlePopulationCount(HloInstructionPtr hlo) {
  return this->HandleElementwiseUnary(hlo);
}
virtual Status HandleShiftLeft(HloInstructionPtr hlo) { virtual Status HandleShiftLeft(HloInstructionPtr hlo) {
return HandleElementwiseBinary(hlo); return HandleElementwiseBinary(hlo);
} }

View File

@ -326,6 +326,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
} }
return Unimplemented("unary op Not is not defined for type '%d'", type); return Unimplemented("unary op Not is not defined for type '%d'", type);
} }
case HloOpcode::kPopulationCount: {
// Population count is emitted as a call to LLVM's ctpop intrinsic on the
// operand value; the operand's own LLVM type is passed alongside it.
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctpop,
{operand_value},
{operand_value->getType()}, b_);
}
default: default:
return Unimplemented("unary integer op '%s'", return Unimplemented("unary integer op '%s'",
HloOpcodeString(op->opcode())); HloOpcodeString(op->opcode()));
@ -2219,6 +2224,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
case HloOpcode::kLog1p: case HloOpcode::kLog1p:
case HloOpcode::kNegate: case HloOpcode::kNegate:
case HloOpcode::kNot: case HloOpcode::kNot:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal: case HloOpcode::kReal:
case HloOpcode::kRsqrt: case HloOpcode::kRsqrt:
case HloOpcode::kSign: case HloOpcode::kSign:

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_TYPED_VISITOR_H_
#include <bitset>
#include <cmath> #include <cmath>
#include <type_traits> #include <type_traits>
@ -2482,6 +2483,37 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleClz<ElementwiseT>(clz); return HandleClz<ElementwiseT>(clz);
} }
// Fallback overload of HandlePopulationCount: selected for every NativeT that
// is none of int32, uint32, int64 or uint64, and reports the element type as
// unsupported instead of evaluating the instruction.
template <typename NativeT,
          typename std::enable_if<
              !std::is_same<NativeT, uint32>::value &&
              !std::is_same<NativeT, int32>::value &&
              !std::is_same<NativeT, uint64>::value &&
              !std::is_same<NativeT, int64>::value>::type* = nullptr>
Status HandlePopulationCount(HloInstruction* popcnt) {
  return UnsupportedTypeError(popcnt);
}
// Supported overload of HandlePopulationCount, selected when NativeT is
// int32, uint32, int64 or uint64. Evaluates the instruction element-wise,
// mapping each element to the number of bits set in it, and records the
// result for `popcnt` in parent_->evaluated_.
template <typename NativeT,
          typename std::enable_if<
              std::is_same<NativeT, uint32>::value ||
              std::is_same<NativeT, int32>::value ||
              std::is_same<NativeT, uint64>::value ||
              std::is_same<NativeT, int64>::value>::type* = nullptr>
Status HandlePopulationCount(HloInstruction* popcnt) {
  // One bitset position per bit of the element's storage, so the count covers
  // the full width of the element type.
  using ElementBits = std::bitset<CHAR_BIT * sizeof(ElementwiseT)>;
  const auto count_set_bits = [](ElementwiseT elem_operand) {
    return ElementBits(elem_operand).count();
  };
  TF_ASSIGN_OR_RETURN(parent_->evaluated_[popcnt],
                      ElementWiseUnaryOp(popcnt, count_set_bits));
  return Status::OK();
}
// Visitor override for population-count instructions: forwards to the
// templated HandlePopulationCount overload selected by this visitor's
// ElementwiseT.
Status HandlePopulationCount(HloInstruction* popcnt) override {
  using OperandT = ElementwiseT;
  return HandlePopulationCount<OperandT>(popcnt);
}
template <typename NativeT, typename std::enable_if<std::is_floating_point< template <typename NativeT, typename std::enable_if<std::is_floating_point<
NativeT>::value>::type* = nullptr> NativeT>::value>::type* = nullptr>
Status HandleSin(HloInstruction* sin) { Status HandleSin(HloInstruction* sin) {

View File

@ -947,6 +947,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kMultiply: case HloOpcode::kMultiply:
case HloOpcode::kNegate: case HloOpcode::kNegate:
case HloOpcode::kNot: case HloOpcode::kNot:
case HloOpcode::kPopulationCount:
case HloOpcode::kOr: case HloOpcode::kOr:
case HloOpcode::kXor: case HloOpcode::kXor:
case HloOpcode::kPower: case HloOpcode::kPower:

View File

@ -703,6 +703,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
case HloOpcode::kLog1p: case HloOpcode::kLog1p:
case HloOpcode::kNot: case HloOpcode::kNot:
case HloOpcode::kNegate: case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal: case HloOpcode::kReal:
case HloOpcode::kRsqrt: case HloOpcode::kRsqrt:
case HloOpcode::kSign: case HloOpcode::kSign:
@ -1408,6 +1409,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kLog1p: case HloOpcode::kLog1p:
case HloOpcode::kNot: case HloOpcode::kNot:
case HloOpcode::kNegate: case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal: case HloOpcode::kReal:
case HloOpcode::kRsqrt: case HloOpcode::kRsqrt:
case HloOpcode::kSign: case HloOpcode::kSign:
@ -1754,6 +1756,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kMinimum: case HloOpcode::kMinimum:
case HloOpcode::kMultiply: case HloOpcode::kMultiply:
case HloOpcode::kNegate: case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kPower: case HloOpcode::kPower:
case HloOpcode::kReal: case HloOpcode::kReal:
case HloOpcode::kRemainder: case HloOpcode::kRemainder:
@ -2135,6 +2138,7 @@ bool HloInstruction::IsElementwiseImpl(
case HloOpcode::kLog1p: case HloOpcode::kLog1p:
case HloOpcode::kNot: case HloOpcode::kNot:
case HloOpcode::kNegate: case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal: case HloOpcode::kReal:
case HloOpcode::kReducePrecision: case HloOpcode::kReducePrecision:
case HloOpcode::kRsqrt: case HloOpcode::kRsqrt:
@ -2600,6 +2604,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleIsFinite(this); return visitor->HandleIsFinite(this);
case HloOpcode::kNot: case HloOpcode::kNot:
return visitor->HandleNot(this); return visitor->HandleNot(this);
case HloOpcode::kPopulationCount:
return visitor->HandlePopulationCount(this);
case HloOpcode::kBitcast: case HloOpcode::kBitcast:
return visitor->HandleBitcast(this); return visitor->HandleBitcast(this);
case HloOpcode::kBroadcast: case HloOpcode::kBroadcast:

View File

@ -108,6 +108,7 @@ namespace xla {
V(kOutfeed, "outfeed", 2) \ V(kOutfeed, "outfeed", 2) \
V(kPad, "pad", 2) \ V(kPad, "pad", 2) \
V(kParameter, "parameter", 0) \ V(kParameter, "parameter", 0) \
V(kPopulationCount, "popcnt", 1) \
V(kPower, "power", 2) \ V(kPower, "power", 2) \
V(kReal, "real", 1) \ V(kReal, "real", 1) \
V(kRecv, "recv", 1) \ V(kRecv, "recv", 1) \

View File

@ -744,6 +744,7 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
case HloOpcode::kLog1p: case HloOpcode::kLog1p:
case HloOpcode::kNot: case HloOpcode::kNot:
case HloOpcode::kNegate: case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal: case HloOpcode::kReal:
case HloOpcode::kRsqrt: case HloOpcode::kRsqrt:
case HloOpcode::kSign: case HloOpcode::kSign:

View File

@ -88,6 +88,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kXor: case HloOpcode::kXor:
case HloOpcode::kOutfeed: case HloOpcode::kOutfeed:
case HloOpcode::kPad: case HloOpcode::kPad:
case HloOpcode::kPopulationCount:
case HloOpcode::kReal: case HloOpcode::kReal:
case HloOpcode::kReducePrecision: case HloOpcode::kReducePrecision:
case HloOpcode::kReplicaId: case HloOpcode::kReplicaId:

View File

@ -2016,6 +2016,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kOr: case HloOpcode::kOr:
case HloOpcode::kXor: case HloOpcode::kXor:
case HloOpcode::kPad: case HloOpcode::kPad:
case HloOpcode::kPopulationCount:
case HloOpcode::kPower: case HloOpcode::kPower:
case HloOpcode::kReal: case HloOpcode::kReal:
case HloOpcode::kReducePrecision: case HloOpcode::kReducePrecision:

View File

@ -308,6 +308,14 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type())); HloOpcodeString(opcode), PrimitiveType_Name(shape.element_type()));
} }
return shape; return shape;
case HloOpcode::kPopulationCount:
// PopulationCount accepts only integral element types; any other element
// type is rejected with InvalidArgument. On success the shape is returned
// unchanged.
if (!ShapeUtil::ElementIsIntegral(shape)) {
return InvalidArgument(
"Expected an integral element type in argument to PopulationCount "
"operation; got %s.",
PrimitiveType_Name(shape.element_type()));
}
return shape;
case HloOpcode::kSign: case HloOpcode::kSign:
if (!ShapeUtil::ElementIsSigned(shape) && if (!ShapeUtil::ElementIsSigned(shape) &&
!ShapeUtil::ElementIsComplex(shape)) { !ShapeUtil::ElementIsComplex(shape)) {

View File

@ -1065,6 +1065,29 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) {
ComputeAndCompareR1<uint32>(&builder, {}, {}); ComputeAndCompareR1<uint32>(&builder, {}, {});
} }
// PopulationCount over a rank-1 int32 array, including a negative element.
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR1) {
  XlaBuilder builder(TestName());
  // -15 is 0xFFFFFFF1 in two's complement (29 bits set); 341 is 0b101010101
  // (5 bits set).
  PopulationCount(ConstantR1<int32>(&builder, {0, 1, -15, 341}));
  ComputeAndCompareR1<int32>(&builder, {0, 1, 29, 5}, {});
}
// PopulationCount over a rank-2 int32 array; same values as the rank-1 case,
// laid out as a 2x2 matrix.
XLA_TEST_F(ArrayElementwiseOpTest, PopcntR2) {
  XlaBuilder builder(TestName());
  PopulationCount(ConstantR2<int32>(&builder, {{0, 1}, {-15, 341}}));
  const Array2D<int32> expected_counts({{0, 1}, {29, 5}});
  ComputeAndCompareR2<int32>(&builder, expected_counts, {});
}
// PopulationCount over 64-bit signed values: zero, all bits set (-1), and the
// two largest positive values (63 and 62 bits set).
XLA_TEST_F(ArrayElementwiseOpTest, PopcntS64) {
  XlaBuilder builder(TestName());
  PopulationCount(
      ConstantR2<int64>(&builder, {{0, -1}, {INT64_MAX, INT64_MAX - 1}}));
  const Array2D<int64> expected_counts({{0, 64}, {63, 62}});
  ComputeAndCompareR2<int64>(&builder, expected_counts, {});
}
XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) {
XlaBuilder builder(TestName()); XlaBuilder builder(TestName());
auto a = ConstantR1<int32>( auto a = ConstantR1<int32>(