[XLA] Add kLogistic HLO to allow for custom HLO lowering based on target.
This will enable target specific lowering for the logistic function to be performed when this HLO is used. PiperOrigin-RevId: 318159527 Change-Id: I453782fea99838fddd9039f63faa5c876cb7dec0
This commit is contained in:
parent
0780433b32
commit
9f292402a6
@ -1288,6 +1288,9 @@ if and only if the corresponding input element is finite.
|
||||
|
||||
<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
|
||||
|
||||
<b>`Logistic(operand)`</b> Element-wise logistic function computation `x ->
|
||||
logistic(x)`.
|
||||
|
||||
<b>`PopulationCount(operand)`</b> Computes the number of bits set in each
|
||||
element of `operand`.
|
||||
|
||||
|
@ -150,6 +150,9 @@ class DfsHloVisitorBase {
|
||||
virtual Status HandleRound(HloInstructionPtr hlo) {
|
||||
return HandleElementwiseUnary(hlo);
|
||||
}
|
||||
virtual Status HandleLogistic(HloInstructionPtr hlo) {
|
||||
return HandleElementwiseUnary(hlo);
|
||||
}
|
||||
virtual Status HandleSign(HloInstructionPtr hlo) {
|
||||
return HandleElementwiseUnary(hlo);
|
||||
}
|
||||
|
@ -99,12 +99,14 @@ Status HloCostAnalysis::HandleElementwiseOp(
|
||||
auto opcode = hlo_instruction->opcode();
|
||||
// We treat transcendental operations separately since one transcendental
|
||||
// operation can correspond to several floating point ops.
|
||||
// kLogistic is included in "trascendental" as it is implemented using
|
||||
// trascendental ops (tanh or exp).
|
||||
if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog ||
|
||||
opcode == HloOpcode::kPower || opcode == HloOpcode::kSqrt ||
|
||||
opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh ||
|
||||
opcode == HloOpcode::kSin || opcode == HloOpcode::kCos ||
|
||||
opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p ||
|
||||
opcode == HloOpcode::kAtan2) {
|
||||
opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower ||
|
||||
opcode == HloOpcode::kSqrt || opcode == HloOpcode::kRsqrt ||
|
||||
opcode == HloOpcode::kTanh || opcode == HloOpcode::kSin ||
|
||||
opcode == HloOpcode::kCos || opcode == HloOpcode::kExpm1 ||
|
||||
opcode == HloOpcode::kLog1p || opcode == HloOpcode::kAtan2) {
|
||||
current_properties_[kTranscendentalsKey] = computation_count;
|
||||
} else {
|
||||
// Note: transcendental operations are considered a separate category from
|
||||
|
@ -451,6 +451,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
|
||||
return HandleNegate<ReturnT>(negate);
|
||||
}
|
||||
|
||||
Status HandleLogistic(HloInstruction* logistic) override {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
parent_->evaluated_[logistic],
|
||||
ElementWiseUnaryOp(logistic, [](ElementwiseT elem_operand) {
|
||||
return static_cast<ElementwiseT>(1) /
|
||||
(static_cast<ElementwiseT>(1) + std::exp(-elem_operand));
|
||||
}));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename NativeT,
|
||||
typename std::enable_if<std::is_integral<NativeT>::value>::type* =
|
||||
nullptr>
|
||||
|
@ -977,6 +977,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
|
||||
case HloOpcode::kShiftLeft:
|
||||
case HloOpcode::kShiftRightArithmetic:
|
||||
case HloOpcode::kShiftRightLogical:
|
||||
case HloOpcode::kLogistic:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSlice:
|
||||
|
@ -833,6 +833,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
|
||||
case HloOpcode::kPopulationCount:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kLogistic:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSqrt:
|
||||
@ -1615,6 +1616,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
|
||||
case HloOpcode::kPopulationCount:
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kLogistic:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSqrt:
|
||||
@ -1993,6 +1995,7 @@ bool HloInstruction::IdenticalSlowPath(
|
||||
case HloOpcode::kShiftLeft:
|
||||
case HloOpcode::kShiftRightArithmetic:
|
||||
case HloOpcode::kShiftRightLogical:
|
||||
case HloOpcode::kLogistic:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSqrt:
|
||||
@ -2440,6 +2443,7 @@ bool HloInstruction::IsElementwiseImpl(
|
||||
case HloOpcode::kReal:
|
||||
case HloOpcode::kReducePrecision:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kLogistic:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSqrt:
|
||||
@ -2854,6 +2858,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
|
||||
return visitor->HandleBatchNormInference(this);
|
||||
case HloOpcode::kBatchNormGrad:
|
||||
return visitor->HandleBatchNormGrad(this);
|
||||
case HloOpcode::kLogistic:
|
||||
return visitor->HandleLogistic(this);
|
||||
case HloOpcode::kSign:
|
||||
return visitor->HandleSign(this);
|
||||
case HloOpcode::kConstant:
|
||||
|
@ -98,6 +98,7 @@ namespace xla {
|
||||
V(kIsFinite, "is-finite", 1) \
|
||||
V(kLog, "log", 1) \
|
||||
V(kLog1p, "log-plus-one", 1) \
|
||||
V(kLogistic, "logistic", 1) \
|
||||
V(kAnd, "and", 2) \
|
||||
V(kNot, "not", 1) \
|
||||
V(kOr, "or", 2) \
|
||||
|
@ -888,6 +888,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kLogistic:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kPopulationCount:
|
||||
|
@ -161,6 +161,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
|
||||
case HloOpcode::kGather:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kLogistic:
|
||||
case HloOpcode::kMap:
|
||||
case HloOpcode::kParameter:
|
||||
case HloOpcode::kPower:
|
||||
|
@ -2193,6 +2193,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kLogistic:
|
||||
case HloOpcode::kMap:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
|
@ -256,6 +256,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||
case HloOpcode::kExpm1:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kLogistic:
|
||||
case HloOpcode::kRsqrt:
|
||||
case HloOpcode::kSqrt:
|
||||
case HloOpcode::kCbrt:
|
||||
|
@ -216,6 +216,7 @@ const HloInstruction* PickRepresentativeOperand(
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kLog1p:
|
||||
case HloOpcode::kLogistic:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kMultiply:
|
||||
|
Loading…
Reference in New Issue
Block a user