Adding support to identify a subset of possible auxiliary

induction variables (AIV). Specifically, candidates are gtes, e.g.,
gte(param0, N). We check if the loop body plumbs the AIV
through the same tuple index at root, and that ops involving AIV
involve constants.
   op2 = op(constants, gte(param0, N), constants)
   op3 = op(constants, f(op2, gte(param0, N), constants)
   root = tuple(..., op3, ...)
Further, the ops are restricted to basic math ops (+,-,*,/).

PiperOrigin-RevId: 353357828
Change-Id: If2bbd7a4d4758982100c2a29e618ac40abf1de1e
This commit is contained in:
A. Unique TensorFlower 2021-01-22 18:36:30 -08:00 committed by TensorFlower Gardener
parent 605c0759d1
commit 62a1fad4fa
6 changed files with 329 additions and 0 deletions

View File

@ -2489,6 +2489,7 @@ cc_library(
":hlo",
":hlo_evaluator",
":pattern_matcher",
"//tensorflow/compiler/xla/service:hlo_reachability",
"@com_google_absl//absl/base",
"@com_google_absl//absl/types:optional",
],

View File

@ -73,6 +73,23 @@ void HloReachabilityMap::SetReachable(Index a, Index b) {
GetBitVector(b).Set(a.v);
}
std::unique_ptr<HloReachabilityMap> HloReachabilityMap::BuildWithRestrictions(
const HloComputation* computation,
absl::FunctionRef<void(const HloInstruction*,
std::vector<HloInstruction*>*)>
add_dependencies) {
const auto& all = computation->MakeInstructionPostOrder();
auto result = absl::make_unique<HloReachabilityMap>(all);
std::vector<HloInstruction*> inputs;
for (const HloInstruction* hlo : all) {
inputs.clear();
add_dependencies(hlo, &inputs);
result->FastSetReachabilityToUnion(inputs, hlo);
}
return result;
}
std::unique_ptr<HloReachabilityMap> HloReachabilityMap::Build(
const HloComputation* computation) {
const auto& all = computation->MakeInstructionPostOrder();

View File

@ -56,6 +56,18 @@ class HloReachabilityMap {
static std::unique_ptr<HloReachabilityMap> Build(
const HloComputation* computation);
// Similar to the above Build operation except that it tries to identify
// paths between instructions that do not contain control instructions
// and multiple operands, i.e., b is_reachable a == true iff
// b = f(f(f(f(f(a), constant), constant), constant).
// Further, the only ops allowed in a path are basic math operations such
// as add, sub, mul, div.
static std::unique_ptr<HloReachabilityMap> BuildWithRestrictions(
const HloComputation* computation,
absl::FunctionRef<void(const HloInstruction*,
std::vector<HloInstruction*>*)>
add_dependencies);
// Set the reachability set of 'instruction' to the union of the reachability
// sets of 'inputs'. Upon return, IsReachable(x, instruction) where
// 'x' is not 'instruction' will return true iff IsReachable(x, input) is true

View File

@ -14,11 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
#include "absl/base/casts.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
namespace xla {
@ -78,6 +80,146 @@ static optional<int64> GetGTEOperandIndex(const HloInstruction* instr,
return tuple_idx;
}
// The below function identifies a subset of all possible auxiliary
// induction variables (AIV). Specifically, candidates are gtes, e.g.,
// gte(param0, N)
// The function checks if the loop body plumbs the AIV
// through the same tuple index at root, and that ops involving AIV
// involve constants.
// op2 = op(constants, gte(param0, N), constants)
// op3 = op(constants, f(op2, gte(param0, N), constants)
// op4 = op(constants, f(op3, constants)
// root = tuple(..., op4, ...)
// Further, the ops are restricted to basic math ops (+,-,*,/).
// Finally, loop invariant GTEs are excluded from AIVs.
// We can expand the ops category/nature of AIVs as needed.
std::vector<const HloInstruction*> GetAuxiliaryLoopInductionVars(
const HloInstruction* while_op) {
std::vector<const HloInstruction*> aux_ind_gte;
CHECK_EQ(while_op->opcode(), HloOpcode::kWhile);
auto* while_body = while_op->while_body();
auto* while_body_param = while_body->parameter_instruction(0);
VLOG(2) << "Aux Induction Variables for loop:" << while_op->ToShortString();
VLOG(2) << "the parameter instr:" << while_body_param->ToShortString();
VLOG(2) << "the parameter user count:" << while_body_param->users().size();
if (while_body_param == nullptr) return aux_ind_gte;
// candidates_pairs = pair<inst, inst>(
// operands of the root while body,
// GTE only operands that index into the same position in the parameter)
// for each candidate_pair (x, y)
// find all paths between x and y,
// each paths should satisfy the above listed criterion
// index that x and y used is added as a aux variable index
std::map<int64, const HloInstruction*> extractions;
for (const HloInstruction* indx_instr : while_body_param->users()) {
if (indx_instr->opcode() != HloOpcode::kGetTupleElement) {
continue;
}
auto it = extractions.find(indx_instr->tuple_index());
// if we find two extractions at the same index, we ignore such
// a candidate
if (it != extractions.end()) {
it->second = nullptr;
VLOG(2) << "two extractions at same index:" << indx_instr->ToString();
} else {
extractions.insert(std::make_pair(indx_instr->tuple_index(), indx_instr));
VLOG(2) << "inserting extraction :" << indx_instr->ToString();
}
}
VLOG(2) << "total extractions size:" << extractions.size() << std::endl;
if (extractions.empty()) {
return aux_ind_gte;
}
auto* while_body_root = while_body->root_instruction();
if (while_body_root->opcode() != HloOpcode::kTuple) {
VLOG(2) << "While body root is not a tuple:" << while_body_root->ToString();
return aux_ind_gte;
}
int64 index = -1;
std::map<int64, const HloInstruction*> insertions;
for (const HloInstruction* operand : while_body_root->operands()) {
index++;
if (!operand->IsConstant()) {
auto it = insertions.find(index);
if (it != insertions.end()) {
it->second = nullptr;
VLOG(2) << "two insertions at same index:" << operand->ToString();
} else {
insertions.insert(std::make_pair(index, operand));
VLOG(2) << "inserting insertions:" << operand->ToString();
}
}
}
if (insertions.empty()) {
return aux_ind_gte;
}
std::map<int64, std::pair<const HloInstruction*, const HloInstruction*>>
candidate_pairs;
for (; index >= 0; --index) {
const HloInstruction *ext, *inst;
ext = (extractions.find(index) != extractions.end())
? extractions.find(index)->second
: nullptr;
inst = (insertions.find(index) != insertions.end())
? insertions.find(index)->second
: nullptr;
if (ext != nullptr && inst != nullptr) {
// Filter out trivial aux, i.e., extract directly to an insert.
if (ext != inst) {
candidate_pairs.insert(
std::make_pair(index, std::make_pair(ext, inst)));
}
}
}
VLOG(2) << "total candidate pairs:" << candidate_pairs.size() << std::endl;
// Passed to ReachabilityMap to decide the type of produce-consumer edges
// along the reachability path.
const auto add_dependencies = [](const HloInstruction* hlo,
std::vector<HloInstruction*>* inputs) {
HloInstruction* non_const_operand = nullptr;
int num_non_constants = 0;
for (HloInstruction* operand : hlo->operands()) {
if (!operand->IsConstant()) {
num_non_constants++;
non_const_operand = operand;
}
}
if (num_non_constants == 1 &&
(hlo->opcode() == HloOpcode::kGetTupleElement ||
hlo->opcode() == HloOpcode::kAdd ||
hlo->opcode() == HloOpcode::kMultiply ||
hlo->opcode() == HloOpcode::kDivide ||
hlo->opcode() == HloOpcode::kSubtract)) {
inputs->push_back(non_const_operand);
}
};
std::unique_ptr<HloReachabilityMap> hrm =
HloReachabilityMap::BuildWithRestrictions(
while_body,
absl::FunctionRef<void(const HloInstruction* hlo,
std::vector<HloInstruction*>* inputs)>(
add_dependencies));
for (auto candidates : candidate_pairs) {
VLOG(2) << "are reachable?:" << (candidates.second.first)->ToString()
<< "*************" << (candidates.second.second)->ToString()
<< std::endl;
if (hrm->IsReachable(candidates.second.first, candidates.second.second)) {
aux_ind_gte.push_back(candidates.second.first);
VLOG(2) << "YES";
} else {
VLOG(2) << "NO";
}
}
VLOG(2) << "num auxiliary candidates :" << aux_ind_gte.size();
return aux_ind_gte;
}
// Tries to get the tuple index of the induction variable of a while loop.
//
// Checks that the loop condition and body both plumb the induction variable

View File

@ -36,6 +36,11 @@ absl::optional<int64> ComputeWhileLoopTripCount(
absl::optional<int64> ComputeWhileLoopTripCountUpperBound(
HloInstruction *while_op);
// The below function identifies a subset of all possible auxiliary
// induction variables (AIV). Specifically, candidates are gtes, e.g.,
// gte(param0, N)
std::vector<const HloInstruction *> GetAuxiliaryLoopInductionVars(
const HloInstruction *while_op);
// Returns the tuple index of the loop induction variable if there is such an
// induction variable detected. Otherwise returns nullopt.
absl::optional<int64> GetLoopInductionVarTupleIdx(

View File

@ -120,5 +120,157 @@ TEST_F(WhileLoopAnalysisTest, ExactBound) {
EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 42);
}
TEST_F(WhileLoopAnalysisTest, NoAIVNoConstChain) {
const char* const kHloModule = R"(
HloModule ModuleWithWhile
body {
p_body = (f32[2], s32[], s32[]) parameter(0)
val1 = f32[2] get-tuple-element(p_body), index=0
val2 = s32[] get-tuple-element(p_body), index=1
val3 = s32[] get-tuple-element(p_body), index=2
add = s32[] add(val2, val3)
sub = s32[] subtract(add, val3)
ROOT root = (f32[2], s32[], s32[]) tuple(val1, add, sub)
}
condition {
p_cond = (f32[2], s32[], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=1
const = s32[] constant(42)
ROOT result = pred[] compare(gte, const), direction=EQ
}
ENTRY entry {
param.0 = f32[2] parameter(0)
param.1 = s32[] parameter(1)
param.2 = s32[] parameter(2)
while_init = (f32[2], s32[], s32[]) tuple(param.0, param.1, param.2)
ROOT while = (f32[2], s32[], s32[]) while(while_init), condition=condition, body=body
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloModule));
HloInstruction* while_op = module->entry_computation()->root_instruction();
std::vector<const HloInstruction*> aux_indices =
GetAuxiliaryLoopInductionVars(while_op);
EXPECT_EQ(aux_indices.size(), 0);
}
TEST_F(WhileLoopAnalysisTest, AIVMultiChain) {
const char* const kHloModule = R"(
HloModule ModuleWithWhile
body {
p_body = (f32[2], s32[]) parameter(0)
val1 = f32[2] get-tuple-element(p_body), index=0
val2 = s32[] get-tuple-element(p_body), index=1
const.1 = s32[] constant(42)
const.2 = s32[] constant(42)
const.3 = s32[] constant(42)
add = s32[] add(val2, const.1)
sub = s32[] subtract(add, const.2)
mul = s32[] multiply(sub, const.3)
ROOT root = (f32[2], s32[]) tuple(val1, mul)
}
condition {
p_cond = (f32[2], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=1
const = s32[] constant(42)
ROOT result = pred[] compare(gte, const), direction=EQ
}
ENTRY entry {
param.0 = f32[2] parameter(0)
param.1 = s32[] parameter(1)
while_init = (f32[2], s32[]) tuple(param.0, param.1)
ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloModule));
HloInstruction* while_op = module->entry_computation()->root_instruction();
std::vector<const HloInstruction*> aux_indices =
GetAuxiliaryLoopInductionVars(while_op);
EXPECT_EQ(aux_indices.size(), 1);
EXPECT_EQ(aux_indices[0]->opcode(), HloOpcode::kGetTupleElement);
EXPECT_EQ(aux_indices[0]->tuple_index(), 1);
}
TEST_F(WhileLoopAnalysisTest, NoAIV) {
const char* const kHloModule = R"(
HloModule ModuleWithWhile
body {
p_body = (f32[2], s32[]) parameter(0)
val1 = f32[2] get-tuple-element(p_body), index=0
val2 = s32[] get-tuple-element(p_body), index=1
add = s32[] add(val2, val2)
const.1 = s32[] constant(42)
mul = s32[] multiply(add, const.1)
div = s32[] divide(mul, add)
ROOT root = (f32[2], s32[]) tuple(val1, div)
}
condition {
p_cond = (f32[2], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=1
const = s32[] constant(42)
ROOT result = pred[] compare(gte, const), direction=EQ
}
ENTRY entry {
param.0 = f32[2] parameter(0)
param.1 = s32[] parameter(1)
while_init = (f32[2], s32[]) tuple(param.0, param.1)
ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloModule));
HloInstruction* while_op = module->entry_computation()->root_instruction();
std::vector<const HloInstruction*> aux_indices =
GetAuxiliaryLoopInductionVars(while_op);
EXPECT_EQ(aux_indices.size(), 0);
}
TEST_F(WhileLoopAnalysisTest, AIVNoChain) {
const char* const kHloModule = R"(
HloModule ModuleWithWhile
body {
p_body = (f32[2], s32[]) parameter(0)
val1 = f32[2] get-tuple-element(p_body), index=0
val2 = s32[] get-tuple-element(p_body), index=1
const = s32[] constant(42)
add = s32[] add(val2, const)
ROOT root = (f32[2], s32[]) tuple(val1, add)
}
condition {
p_cond = (f32[2], s32[]) parameter(0)
gte = s32[] get-tuple-element(p_cond), index=1
const = s32[] constant(42)
ROOT result = pred[] compare(gte, const), direction=EQ
}
ENTRY entry {
param.0 = f32[2] parameter(0)
param.1 = s32[] parameter(1)
while_init = (f32[2], s32[]) tuple(param.0, param.1)
ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(kHloModule));
HloInstruction* while_op = module->entry_computation()->root_instruction();
std::vector<const HloInstruction*> aux_indices =
GetAuxiliaryLoopInductionVars(while_op);
EXPECT_EQ(aux_indices.size(), 1);
EXPECT_EQ(aux_indices[0]->opcode(), HloOpcode::kGetTupleElement);
EXPECT_EQ(aux_indices[0]->tuple_index(), 1);
}
} // namespace
} // namespace xla