[XLA] Make GetLoopInductionVarTupleIdx external.

PiperOrigin-RevId: 237072928
This commit is contained in:
Bixia Zheng 2019-03-06 10:25:33 -08:00 committed by TensorFlower Gardener
parent 9d703eecbf
commit 7a415783c6
2 changed files with 7 additions and 3 deletions

View File

@ -80,7 +80,7 @@ static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
// Tries to get the tuple index of the induction variable of a while loop.
//
// Checks that the loop condition and root both plumb the induction variable
// Checks that the loop condition and body both plumb the induction variable
// through the same tuple index, and that they both apply exactly one op to the
// induction variable before deciding whether to do another loop iteration (in
// the loop condition's case) or packing the induction variable into the result
@ -96,8 +96,7 @@ static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
// root = tuple(..., inc, ...) // inc is N'th operand of tuple().
//
// If so, returns N. Otherwise, returns nullopt.
static optional<int64> GetLoopInductionVarTupleIdx(
const HloInstruction* while_op) {
optional<int64> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
VLOG(2) << "Finding induction variable for loop "
<< while_op->ToShortString();

View File

@ -35,6 +35,11 @@ absl::optional<int64> ComputeWhileLoopTripCount(
// known, nullopt otherwise.
absl::optional<int64> ComputeWhileLoopTripCountUpperBound(
HloInstruction *while_op);
// Returns the tuple index of the loop induction variable if there is such an
// induction variable detected. Otherwise returns nullopt.
absl::optional<int64> GetLoopInductionVarTupleIdx(
const HloInstruction *while_op);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_