diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md
index 4d483e2b78e..56aa3feefaa 100644
--- a/tensorflow/compiler/xla/g3doc/operation_semantics.md
+++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md
@@ -1288,6 +1288,9 @@ if and only if the corresponding input element is finite.
`LogicalNot(operand)` Element-wise logical not `x -> !(x)`.
+`Logistic(operand)` Element-wise logistic function computation `x ->
+logistic(x)`.
+
`PopulationCount(operand)` Computes the number of bits set in each
element of `operand`.
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index bdaac32a0e5..b0def1a2dd8 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -150,6 +150,9 @@ class DfsHloVisitorBase {
virtual Status HandleRound(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
+ virtual Status HandleLogistic(HloInstructionPtr hlo) {
+ return HandleElementwiseUnary(hlo);
+ }
virtual Status HandleSign(HloInstructionPtr hlo) {
return HandleElementwiseUnary(hlo);
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 8a31bc5fef4..814643718ba 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -99,12 +99,14 @@ Status HloCostAnalysis::HandleElementwiseOp(
auto opcode = hlo_instruction->opcode();
// We treat transcendental operations separately since one transcendental
// operation can correspond to several floating point ops.
+ // kLogistic is included in "trascendental" as it is implemented using
+ // trascendental ops (tanh or exp).
if (opcode == HloOpcode::kExp || opcode == HloOpcode::kLog ||
- opcode == HloOpcode::kPower || opcode == HloOpcode::kSqrt ||
- opcode == HloOpcode::kRsqrt || opcode == HloOpcode::kTanh ||
- opcode == HloOpcode::kSin || opcode == HloOpcode::kCos ||
- opcode == HloOpcode::kExpm1 || opcode == HloOpcode::kLog1p ||
- opcode == HloOpcode::kAtan2) {
+ opcode == HloOpcode::kLogistic || opcode == HloOpcode::kPower ||
+ opcode == HloOpcode::kSqrt || opcode == HloOpcode::kRsqrt ||
+ opcode == HloOpcode::kTanh || opcode == HloOpcode::kSin ||
+ opcode == HloOpcode::kCos || opcode == HloOpcode::kExpm1 ||
+ opcode == HloOpcode::kLog1p || opcode == HloOpcode::kAtan2) {
current_properties_[kTranscendentalsKey] = computation_count;
} else {
// Note: transcendental operations are considered a separate category from
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 3dc9cc24734..1a154f32a6f 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -451,6 +451,16 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
return HandleNegate(negate);
}
+ Status HandleLogistic(HloInstruction* logistic) override {
+ TF_ASSIGN_OR_RETURN(
+ parent_->evaluated_[logistic],
+ ElementWiseUnaryOp(logistic, [](ElementwiseT elem_operand) {
+ return static_cast(1) /
+ (static_cast(1) + std::exp(-elem_operand));
+ }));
+ return Status::OK();
+ }
+
template ::value>::type* =
nullptr>
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index ad21efa13c9..a50af6bf1b9 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -977,6 +977,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
+ case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSlice:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 6de76c1cc63..e9a04583bdf 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -833,6 +833,7 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kRsqrt:
+ case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
@@ -1615,6 +1616,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands(
case HloOpcode::kPopulationCount:
case HloOpcode::kReal:
case HloOpcode::kRsqrt:
+ case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
@@ -1993,6 +1995,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical:
+ case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
@@ -2440,6 +2443,7 @@ bool HloInstruction::IsElementwiseImpl(
case HloOpcode::kReal:
case HloOpcode::kReducePrecision:
case HloOpcode::kRsqrt:
+ case HloOpcode::kLogistic:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSqrt:
@@ -2854,6 +2858,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) {
return visitor->HandleBatchNormInference(this);
case HloOpcode::kBatchNormGrad:
return visitor->HandleBatchNormGrad(this);
+ case HloOpcode::kLogistic:
+ return visitor->HandleLogistic(this);
case HloOpcode::kSign:
return visitor->HandleSign(this);
case HloOpcode::kConstant:
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 92359bcbdac..1625d0bbae4 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -98,6 +98,7 @@ namespace xla {
V(kIsFinite, "is-finite", 1) \
V(kLog, "log", 1) \
V(kLog1p, "log-plus-one", 1) \
+ V(kLogistic, "logistic", 1) \
V(kAnd, "and", 2) \
V(kNot, "not", 1) \
V(kOr, "or", 2) \
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 86475ce76f4..22cd34f3378 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -888,6 +888,7 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
+ case HloOpcode::kLogistic:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kPopulationCount:
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 02966cc2bf2..8d8930615b2 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -161,6 +161,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kGather:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
+ case HloOpcode::kLogistic:
case HloOpcode::kMap:
case HloOpcode::kParameter:
case HloOpcode::kPower:
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index a35ba140e86..3c48668e742 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -2193,6 +2193,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
+ case HloOpcode::kLogistic:
case HloOpcode::kMap:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index bb4a38ded1e..40a28d90f0a 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -256,6 +256,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape,
case HloOpcode::kExpm1:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
+ case HloOpcode::kLogistic:
case HloOpcode::kRsqrt:
case HloOpcode::kSqrt:
case HloOpcode::kCbrt:
diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc
index 2eba9279df4..46ef132c1c0 100644
--- a/tensorflow/compiler/xla/service/sharding_propagation.cc
+++ b/tensorflow/compiler/xla/service/sharding_propagation.cc
@@ -216,6 +216,7 @@ const HloInstruction* PickRepresentativeOperand(
case HloOpcode::kIsFinite:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
+ case HloOpcode::kLogistic:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kMultiply: