diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index 4d483e2b78e..56aa3feefaa 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -1288,6 +1288,9 @@ if and only if the corresponding input element is finite. `LogicalNot(operand)` Element-wise logical not `x -> !(x)`. +`Logistic(operand)` Element-wise logistic function computation `x -> +logistic(x)`. + `PopulationCount(operand)` Computes the number of bits set in each element of `operand`. diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index bdaac32a0e5..b0def1a2dd8 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -150,6 +150,9 @@ class DfsHloVisitorBase { virtual Status HandleRound(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } + virtual Status HandleLogistic(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } virtual Status HandleSign(HloInstructionPtr hlo) { return HandleElementwiseUnary(hlo); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 8a31bc5fef4..814643718ba 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -99,12 +99,14 @@ Status HloCostAnalysis::HandleElementwiseOp( auto opcode = hlo_instruction->opcode(); // We treat transcendental operations separately since one transcendental // operation can correspond to several floating point ops. + // kLogistic is included in "trascendental" as it is implemented using + // trascendental ops (tanh or exp). if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog || - opcode == HloOpcode::kPower || opcode == HloOpcode::kSqrt || - opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh || - opcode == HloOpcode::kSin || opcode == HloOpcode::kCos || - opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p || - opcode == HloOpcode::kAtan2) { + opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower || + opcode == HloOpcode::kSqrt || opcode == HloOpcode::kRsqrt || + opcode == HloOpcode::kTanh || opcode == HloOpcode::kSin || + opcode == HloOpcode::kCos || opcode == HloOpcode::kExpm1 || + opcode == HloOpcode::kLog1p || opcode == HloOpcode::kAtan2) { current_properties_[kTranscendentalsKey] = computation_count; } else { // Note: transcendental operations are considered a separate category from diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h index 3dc9cc24734..1a154f32a6f 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h @@ -451,6 +451,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault { return HandleNegate(negate); } + Status HandleLogistic(HloInstruction* logistic) override { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[logistic], + ElementWiseUnaryOp(logistic, [](ElementwiseT elem_operand) { + return static_cast(1) / + (static_cast(1) + std::exp(-elem_operand)); + })); + return Status::OK(); + } + template ::value>::type* = nullptr> diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index ad21efa13c9..a50af6bf1b9 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -977,6 +977,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: + case HloOpcode::kLogistic: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSlice: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 6de76c1cc63..e9a04583bdf 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -833,6 +833,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kPopulationCount: case HloOpcode::kReal: case HloOpcode::kRsqrt: + case HloOpcode::kLogistic: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: @@ -1615,6 +1616,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kPopulationCount: case HloOpcode::kReal: case HloOpcode::kRsqrt: + case HloOpcode::kLogistic: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: @@ -1993,6 +1995,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: + case HloOpcode::kLogistic: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: @@ -2440,6 +2443,7 @@ bool HloInstruction::IsElementwiseImpl( case HloOpcode::kReal: case HloOpcode::kReducePrecision: case HloOpcode::kRsqrt: + case HloOpcode::kLogistic: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSqrt: @@ -2854,6 +2858,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleBatchNormInference(this); case HloOpcode::kBatchNormGrad: return visitor->HandleBatchNormGrad(this); + case HloOpcode::kLogistic: + return visitor->HandleLogistic(this); case HloOpcode::kSign: return visitor->HandleSign(this); case HloOpcode::kConstant: diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 92359bcbdac..1625d0bbae4 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -98,6 +98,7 @@ namespace xla { V(kIsFinite, "is-finite", 1) \ V(kLog, "log", 1) \ V(kLog1p, "log-plus-one", 1) \ + V(kLogistic, "logistic", 1) \ V(kAnd, "and", 2) \ V(kNot, "not", 1) \ V(kOr, "or", 2) \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 86475ce76f4..22cd34f3378 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -888,6 +888,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kFloor: case HloOpcode::kLog: case HloOpcode::kLog1p: + case HloOpcode::kLogistic: case HloOpcode::kNot: case HloOpcode::kNegate: case HloOpcode::kPopulationCount: diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 02966cc2bf2..8d8930615b2 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -161,6 +161,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kGather: case HloOpcode::kLog: case HloOpcode::kLog1p: + case HloOpcode::kLogistic: case HloOpcode::kMap: case HloOpcode::kParameter: case HloOpcode::kPower: diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index a35ba140e86..3c48668e742 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2193,6 +2193,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kLog1p: + case HloOpcode::kLogistic: case HloOpcode::kMap: case HloOpcode::kMaximum: case HloOpcode::kMinimum: diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index bb4a38ded1e..40a28d90f0a 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -256,6 +256,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, case HloOpcode::kExpm1: case HloOpcode::kLog: case HloOpcode::kLog1p: + case HloOpcode::kLogistic: case HloOpcode::kRsqrt: case HloOpcode::kSqrt: case HloOpcode::kCbrt: diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index 2eba9279df4..46ef132c1c0 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -216,6 +216,7 @@ const HloInstruction* PickRepresentativeOperand( case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kLog1p: + case HloOpcode::kLogistic: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kMultiply: