[XLA] Add kLogistic HLO to allow for custom HLO lowering based on target.

This will enable target specific lowering for the logistic function
to be performed when this HLO is used.

PiperOrigin-RevId: 318159527
Change-Id: I453782fea99838fddd9039f63faa5c876cb7dec0
This commit is contained in:
Marcello Maggioni 2020-06-24 16:04:47 -07:00 committed by TensorFlower Gardener
parent 0780433b32
commit 9f292402a6
12 changed files with 36 additions and 5 deletions

View File

@ -1288,6 +1288,9 @@ if and only if the corresponding input element is finite.
<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`. <b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
<b>`Logistic(operand)`</b> Element-wise logistic function computation `x ->
logistic(x)`.
<b>`PopulationCount(operand)`</b> Computes the number of bits set in each <b>`PopulationCount(operand)`</b> Computes the number of bits set in each
element of `operand`. element of `operand`.

View File

@ -150,6 +150,9 @@ class DfsHloVisitorBase {
virtual Status HandleRound(HloInstructionPtr hlo) { virtual Status HandleRound(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo); return HandleElementwiseUnary(hlo);
} }
virtual Status HandleLogistic(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
virtual Status HandleSign(HloInstructionPtr hlo) { virtual Status HandleSign(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo); return HandleElementwiseUnary(hlo);
} }

View File

@ -99,12 +99,14 @@ Status HloCostAnalysis::HandleElementwiseOp(
auto opcode = hlo_instruction->opcode(); auto opcode = hlo_instruction->opcode();
// We treat transcendental operations separately since one transcendental // We treat transcendental operations separately since one transcendental
// operation can correspond to several floating point ops. // operation can correspond to several floating point ops.
// kLogistic is included in "trascendental" as it is implemented using
// trascendental ops (tanh or exp).
if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog || if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog ||
opcode == HloOpcode::kPower || opcode == HloOpcode::kSqrt || opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower ||
opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh || opcode == HloOpcode::kSqrt || opcode == HloOpcode::kRsqrt ||
opcode == HloOpcode::kSin || opcode == HloOpcode::kCos || opcode == HloOpcode::kTanh || opcode == HloOpcode::kSin ||
opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p || opcode == HloOpcode::kCos || opcode == HloOpcode::kExpm1 ||
opcode == HloOpcode::kAtan2) { opcode == HloOpcode::kLog1p || opcode == HloOpcode::kAtan2) {
current_properties_[kTranscendentalsKey] = computation_count; current_properties_[kTranscendentalsKey] = computation_count;
} else { } else {
// Note: transcendental operations are considered a separate category from // Note: transcendental operations are considered a separate category from

View File

@ -451,6 +451,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleNegate<ReturnT>(negate); return HandleNegate<ReturnT>(negate);
} }
Status HandleLogistic(HloInstruction* logistic) override {
TF_ASSIGN_OR_RETURN(
parent_->evaluated_[logistic],
ElementWiseUnaryOp(logistic, [](ElementwiseT elem_operand) {
return static_cast<ElementwiseT>(1) /
(static_cast<ElementwiseT>(1) + std::exp(-elem_operand));
}));
return Status::OK();
}
template <typename NativeT, template <typename NativeT,
typename std::enable_if<std::is_integral<NativeT>::value>::type* = typename std::enable_if<std::is_integral<NativeT>::value>::type* =
nullptr> nullptr>

View File

@ -977,6 +977,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kShiftLeft: case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: case HloOpcode::kShiftRightLogical:
case HloOpcode::kLogistic:
case HloOpcode::kSign: case HloOpcode::kSign:
case HloOpcode::kSin: case HloOpcode::kSin:
case HloOpcode::kSlice: case HloOpcode::kSlice:

View File

@ -833,6 +833,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
case HloOpcode::kPopulationCount: case HloOpcode::kPopulationCount:
case HloOpcode::kReal: case HloOpcode::kReal:
case HloOpcode::kRsqrt: case HloOpcode::kRsqrt:
case HloOpcode::kLogistic:
case HloOpcode::kSign: case HloOpcode::kSign:
case HloOpcode::kSin: case HloOpcode::kSin:
case HloOpcode::kSqrt: case HloOpcode::kSqrt:
@ -1615,6 +1616,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kPopulationCount: case HloOpcode::kPopulationCount:
case HloOpcode::kReal: case HloOpcode::kReal:
case HloOpcode::kRsqrt: case HloOpcode::kRsqrt:
case HloOpcode::kLogistic:
case HloOpcode::kSign: case HloOpcode::kSign:
case HloOpcode::kSin: case HloOpcode::kSin:
case HloOpcode::kSqrt: case HloOpcode::kSqrt:
@ -1993,6 +1995,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kShiftLeft: case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: case HloOpcode::kShiftRightLogical:
case HloOpcode::kLogistic:
case HloOpcode::kSign: case HloOpcode::kSign:
case HloOpcode::kSin: case HloOpcode::kSin:
case HloOpcode::kSqrt: case HloOpcode::kSqrt:
@ -2440,6 +2443,7 @@ bool HloInstruction::IsElementwiseImpl(
case HloOpcode::kReal: case HloOpcode::kReal:
case HloOpcode::kReducePrecision: case HloOpcode::kReducePrecision:
case HloOpcode::kRsqrt: case HloOpcode::kRsqrt:
case HloOpcode::kLogistic:
case HloOpcode::kSign: case HloOpcode::kSign:
case HloOpcode::kSin: case HloOpcode::kSin:
case HloOpcode::kSqrt: case HloOpcode::kSqrt:
@ -2854,6 +2858,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleBatchNormInference(this); return visitor->HandleBatchNormInference(this);
case HloOpcode::kBatchNormGrad: case HloOpcode::kBatchNormGrad:
return visitor->HandleBatchNormGrad(this); return visitor->HandleBatchNormGrad(this);
case HloOpcode::kLogistic:
return visitor->HandleLogistic(this);
case HloOpcode::kSign: case HloOpcode::kSign:
return visitor->HandleSign(this); return visitor->HandleSign(this);
case HloOpcode::kConstant: case HloOpcode::kConstant:

View File

@ -98,6 +98,7 @@ namespace xla {
V(kIsFinite, "is-finite", 1) \ V(kIsFinite, "is-finite", 1) \
V(kLog, "log", 1) \ V(kLog, "log", 1) \
V(kLog1p, "log-plus-one", 1) \ V(kLog1p, "log-plus-one", 1) \
V(kLogistic, "logistic", 1) \
V(kAnd, "and", 2) \ V(kAnd, "and", 2) \
V(kNot, "not", 1) \ V(kNot, "not", 1) \
V(kOr, "or", 2) \ V(kOr, "or", 2) \

View File

@ -888,6 +888,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
case HloOpcode::kFloor: case HloOpcode::kFloor:
case HloOpcode::kLog: case HloOpcode::kLog:
case HloOpcode::kLog1p: case HloOpcode::kLog1p:
case HloOpcode::kLogistic:
case HloOpcode::kNot: case HloOpcode::kNot:
case HloOpcode::kNegate: case HloOpcode::kNegate:
case HloOpcode::kPopulationCount: case HloOpcode::kPopulationCount:

View File

@ -161,6 +161,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kGather: case HloOpcode::kGather:
case HloOpcode::kLog: case HloOpcode::kLog:
case HloOpcode::kLog1p: case HloOpcode::kLog1p:
case HloOpcode::kLogistic:
case HloOpcode::kMap: case HloOpcode::kMap:
case HloOpcode::kParameter: case HloOpcode::kParameter:
case HloOpcode::kPower: case HloOpcode::kPower:

View File

@ -2193,6 +2193,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kIsFinite: case HloOpcode::kIsFinite:
case HloOpcode::kLog: case HloOpcode::kLog:
case HloOpcode::kLog1p: case HloOpcode::kLog1p:
case HloOpcode::kLogistic:
case HloOpcode::kMap: case HloOpcode::kMap:
case HloOpcode::kMaximum: case HloOpcode::kMaximum:
case HloOpcode::kMinimum: case HloOpcode::kMinimum:

View File

@ -256,6 +256,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
case HloOpcode::kExpm1: case HloOpcode::kExpm1:
case HloOpcode::kLog: case HloOpcode::kLog:
case HloOpcode::kLog1p: case HloOpcode::kLog1p:
case HloOpcode::kLogistic:
case HloOpcode::kRsqrt: case HloOpcode::kRsqrt:
case HloOpcode::kSqrt: case HloOpcode::kSqrt:
case HloOpcode::kCbrt: case HloOpcode::kCbrt:

View File

@ -216,6 +216,7 @@ const HloInstruction* PickRepresentativeOperand(
case HloOpcode::kIsFinite: case HloOpcode::kIsFinite:
case HloOpcode::kLog: case HloOpcode::kLog:
case HloOpcode::kLog1p: case HloOpcode::kLog1p:
case HloOpcode::kLogistic:
case HloOpcode::kMaximum: case HloOpcode::kMaximum:
case HloOpcode::kMinimum: case HloOpcode::kMinimum:
case HloOpcode::kMultiply: case HloOpcode::kMultiply: