[XLA] Make GetLoopInductionVarTupleIdx external.
PiperOrigin-RevId: 237072928
This commit is contained in:
parent
9d703eecbf
commit
7a415783c6
@ -80,7 +80,7 @@ static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
|
||||
|
||||
// Tries to get the tuple index of the induction variable of a while loop.
|
||||
//
|
||||
// Checks that the loop condition and root both plumb the induction variable
|
||||
// Checks that the loop condition and body both plumb the induction variable
|
||||
// through the same tuple index, and that they both apply exactly one op to the
|
||||
// induction variable before deciding whether to do another loop iteration (in
|
||||
// the loop condition's case) or packing the induction variable into the result
|
||||
@ -96,8 +96,7 @@ static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
|
||||
// root = tuple(..., inc, ...) // inc is N'th operand of tuple().
|
||||
//
|
||||
// If so, returns N. Otherwise, returns nullopt.
|
||||
static optional<int64> GetLoopInductionVarTupleIdx(
|
||||
const HloInstruction* while_op) {
|
||||
optional<int64> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) {
|
||||
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
|
||||
VLOG(2) << "Finding induction variable for loop "
|
||||
<< while_op->ToShortString();
|
||||
|
@ -35,6 +35,11 @@ absl::optional<int64> ComputeWhileLoopTripCount(
|
||||
// known, nullopt otherwise.
|
||||
absl::optional<int64> ComputeWhileLoopTripCountUpperBound(
|
||||
HloInstruction *while_op);
|
||||
|
||||
// Returns the tuple index of the loop induction variable if there is such an
|
||||
// induction variable detected. Otherwise returns nullopt.
|
||||
absl::optional<int64> GetLoopInductionVarTupleIdx(
|
||||
const HloInstruction *while_op);
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ANALYSIS_H_
|
||||
|
Loading…
Reference in New Issue
Block a user