From 62a1fad4fa6955edfa8004e98c8dff51ecd3d18a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 22 Jan 2021 18:36:30 -0800 Subject: [PATCH] Adding support to identify a subset of possible auxiliary induction variables (AIV). Specifically, candidates are gtes, e.g., gte(param0, N). We check if the loop body plumbs the AIV through the same tuple index at root, and that ops involving AIV involve constants. op2 = op(constants, gte(param0, N), constants) op3 = op(constants, f(op2, gte(param0, N), constants) root = tuple(..., op3, ...) Further, the ops are restricted to basic math ops (+,-,*,/). PiperOrigin-RevId: 353357828 Change-Id: If2bbd7a4d4758982100c2a29e618ac40abf1de1e --- tensorflow/compiler/xla/service/BUILD | 1 + .../compiler/xla/service/hlo_reachability.cc | 17 ++ .../compiler/xla/service/hlo_reachability.h | 12 ++ .../xla/service/while_loop_analysis.cc | 142 ++++++++++++++++ .../xla/service/while_loop_analysis.h | 5 + .../xla/service/while_loop_analysis_test.cc | 152 ++++++++++++++++++ 6 files changed, 329 insertions(+) diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index aef2045642b..27462a29114 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -2489,6 +2489,7 @@ cc_library( ":hlo", ":hlo_evaluator", ":pattern_matcher", + "//tensorflow/compiler/xla/service:hlo_reachability", "@com_google_absl//absl/base", "@com_google_absl//absl/types:optional", ], diff --git a/tensorflow/compiler/xla/service/hlo_reachability.cc b/tensorflow/compiler/xla/service/hlo_reachability.cc index 9ea9a585465..259fc3ccb5b 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.cc +++ b/tensorflow/compiler/xla/service/hlo_reachability.cc @@ -73,6 +73,23 @@ void HloReachabilityMap::SetReachable(Index a, Index b) { GetBitVector(b).Set(a.v); } +std::unique_ptr HloReachabilityMap::BuildWithRestrictions( + const HloComputation* computation, + absl::FunctionRef*)> + add_dependencies) { + const auto& all = computation->MakeInstructionPostOrder(); + auto result = absl::make_unique(all); + + std::vector inputs; + for (const HloInstruction* hlo : all) { + inputs.clear(); + add_dependencies(hlo, &inputs); + result->FastSetReachabilityToUnion(inputs, hlo); + } + return result; +} + std::unique_ptr HloReachabilityMap::Build( const HloComputation* computation) { const auto& all = computation->MakeInstructionPostOrder(); diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index 15edf315560..9b981dd4a2a 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -56,6 +56,18 @@ class HloReachabilityMap { static std::unique_ptr Build( const HloComputation* computation); + // Similar to the above Build operation except that it tries to identify + // paths between instructions that do not contain control instructions + // and multiple operands, i.e., b is_reachable a == true iff + // b = f(f(f(f(f(a), constant), constant), constant). + // Further, the only ops allowed in a path are basic math operations such + // as add, sub, mul, div. + static std::unique_ptr BuildWithRestrictions( + const HloComputation* computation, + absl::FunctionRef*)> + add_dependencies); + // Set the reachability set of 'instruction' to the union of the reachability // sets of 'inputs'. Upon return, IsReachable(x, instruction) where // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.cc b/tensorflow/compiler/xla/service/while_loop_analysis.cc index ffa89b6a797..71c039d2bd7 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis.cc @@ -14,11 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/while_loop_analysis.h" + #include "absl/base/casts.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_reachability.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" namespace xla { @@ -78,6 +80,146 @@ static optional GetGTEOperandIndex(const HloInstruction* instr, return tuple_idx; } +// The below function identifies a subset of all possible auxiliary +// induction variables (AIV). Specifically, candidates are gtes, e.g., +// gte(param0, N) +// The function checks if the loop body plumbs the AIV +// through the same tuple index at root, and that ops involving AIV +// involve constants. +// op2 = op(constants, gte(param0, N), constants) +// op3 = op(constants, f(op2, gte(param0, N), constants) +// op4 = op(constants, f(op3, constants) +// root = tuple(..., op4, ...) +// Further, the ops are restricted to basic math ops (+,-,*,/). +// Finally, loop invariant GTEs are excluded from AIVs. +// We can expand the ops category/nature of AIVs as needed. +std::vector GetAuxiliaryLoopInductionVars( + const HloInstruction* while_op) { + std::vector aux_ind_gte; + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + auto* while_body = while_op->while_body(); + auto* while_body_param = while_body->parameter_instruction(0); + VLOG(2) << "Aux Induction Variables for loop:" << while_op->ToShortString(); + VLOG(2) << "the parameter instr:" << while_body_param->ToShortString(); + VLOG(2) << "the parameter user count:" << while_body_param->users().size(); + if (while_body_param == nullptr) return aux_ind_gte; + + // candidates_pairs = pair( + // operands of the root while body, + // GTE only operands that index into the same position in the parameter) + // for each candidate_pair (x, y) + // find all paths between x and y, + // each paths should satisfy the above listed criterion + // index that x and y used is added as a aux variable index + std::map extractions; + for (const HloInstruction* indx_instr : while_body_param->users()) { + if (indx_instr->opcode() != HloOpcode::kGetTupleElement) { + continue; + } + auto it = extractions.find(indx_instr->tuple_index()); + // if we find two extractions at the same index, we ignore such + // a candidate + if (it != extractions.end()) { + it->second = nullptr; + VLOG(2) << "two extractions at same index:" << indx_instr->ToString(); + } else { + extractions.insert(std::make_pair(indx_instr->tuple_index(), indx_instr)); + VLOG(2) << "inserting extraction :" << indx_instr->ToString(); + } + } + VLOG(2) << "total extractions size:" << extractions.size() << std::endl; + if (extractions.empty()) { + return aux_ind_gte; + } + + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body root is not a tuple:" << while_body_root->ToString(); + return aux_ind_gte; + } + int64 index = -1; + std::map insertions; + for (const HloInstruction* operand : while_body_root->operands()) { + index++; + if (!operand->IsConstant()) { + auto it = insertions.find(index); + if (it != insertions.end()) { + it->second = nullptr; + VLOG(2) << "two insertions at same index:" << operand->ToString(); + } else { + insertions.insert(std::make_pair(index, operand)); + VLOG(2) << "inserting insertions:" << operand->ToString(); + } + } + } + if (insertions.empty()) { + return aux_ind_gte; + } + + std::map> + candidate_pairs; + for (; index >= 0; --index) { + const HloInstruction *ext, *inst; + ext = (extractions.find(index) != extractions.end()) + ? extractions.find(index)->second + : nullptr; + inst = (insertions.find(index) != insertions.end()) + ? insertions.find(index)->second + : nullptr; + if (ext != nullptr && inst != nullptr) { + // Filter out trivial aux, i.e., extract directly to an insert. + if (ext != inst) { + candidate_pairs.insert( + std::make_pair(index, std::make_pair(ext, inst))); + } + } + } + VLOG(2) << "total candidate pairs:" << candidate_pairs.size() << std::endl; + + // Passed to ReachabilityMap to decide the type of produce-consumer edges + // along the reachability path. + const auto add_dependencies = [](const HloInstruction* hlo, + std::vector* inputs) { + HloInstruction* non_const_operand = nullptr; + int num_non_constants = 0; + for (HloInstruction* operand : hlo->operands()) { + if (!operand->IsConstant()) { + num_non_constants++; + non_const_operand = operand; + } + } + if (num_non_constants == 1 && + (hlo->opcode() == HloOpcode::kGetTupleElement || + hlo->opcode() == HloOpcode::kAdd || + hlo->opcode() == HloOpcode::kMultiply || + hlo->opcode() == HloOpcode::kDivide || + hlo->opcode() == HloOpcode::kSubtract)) { + inputs->push_back(non_const_operand); + } + }; + + std::unique_ptr hrm = + HloReachabilityMap::BuildWithRestrictions( + while_body, + absl::FunctionRef* inputs)>( + add_dependencies)); + + for (auto candidates : candidate_pairs) { + VLOG(2) << "are reachable?:" << (candidates.second.first)->ToString() + << "*************" << (candidates.second.second)->ToString() + << std::endl; + if (hrm->IsReachable(candidates.second.first, candidates.second.second)) { + aux_ind_gte.push_back(candidates.second.first); + VLOG(2) << "YES"; + } else { + VLOG(2) << "NO"; + } + } + VLOG(2) << "num auxiliary candidates :" << aux_ind_gte.size(); + return aux_ind_gte; +} + // Tries to get the tuple index of the induction variable of a while loop. // // Checks that the loop condition and body both plumb the induction variable diff --git a/tensorflow/compiler/xla/service/while_loop_analysis.h b/tensorflow/compiler/xla/service/while_loop_analysis.h index 10b64459974..9bd27a7c2ba 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis.h +++ b/tensorflow/compiler/xla/service/while_loop_analysis.h @@ -36,6 +36,11 @@ absl::optional ComputeWhileLoopTripCount( absl::optional ComputeWhileLoopTripCountUpperBound( HloInstruction *while_op); +// The below function identifies a subset of all possible auxiliary +// induction variables (AIV). Specifically, candidates are gtes, e.g., +// gte(param0, N) +std::vector GetAuxiliaryLoopInductionVars( + const HloInstruction *while_op); // Returns the tuple index of the loop induction variable if there is such an // induction variable detected. Otherwise returns nullopt. absl::optional GetLoopInductionVarTupleIdx( diff --git a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc index 5a5dc742c03..67fc456f9ae 100644 --- a/tensorflow/compiler/xla/service/while_loop_analysis_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_analysis_test.cc @@ -120,5 +120,157 @@ TEST_F(WhileLoopAnalysisTest, ExactBound) { EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 42); } +TEST_F(WhileLoopAnalysisTest, NoAIVNoConstChain) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[], s32[]) parameter(0) + val1 = f32[2] get-tuple-element(p_body), index=0 + val2 = s32[] get-tuple-element(p_body), index=1 + val3 = s32[] get-tuple-element(p_body), index=2 + add = s32[] add(val2, val3) + sub = s32[] subtract(add, val3) + ROOT root = (f32[2], s32[], s32[]) tuple(val1, add, sub) + } + + condition { + p_cond = (f32[2], s32[], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] compare(gte, const), direction=EQ + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + param.2 = s32[] parameter(2) + while_init = (f32[2], s32[], s32[]) tuple(param.0, param.1, param.2) + ROOT while = (f32[2], s32[], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + std::vector aux_indices = + GetAuxiliaryLoopInductionVars(while_op); + EXPECT_EQ(aux_indices.size(), 0); +} + +TEST_F(WhileLoopAnalysisTest, AIVMultiChain) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val1 = f32[2] get-tuple-element(p_body), index=0 + val2 = s32[] get-tuple-element(p_body), index=1 + const.1 = s32[] constant(42) + const.2 = s32[] constant(42) + const.3 = s32[] constant(42) + add = s32[] add(val2, const.1) + sub = s32[] subtract(add, const.2) + mul = s32[] multiply(sub, const.3) + ROOT root = (f32[2], s32[]) tuple(val1, mul) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] compare(gte, const), direction=EQ + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + std::vector aux_indices = + GetAuxiliaryLoopInductionVars(while_op); + EXPECT_EQ(aux_indices.size(), 1); + EXPECT_EQ(aux_indices[0]->opcode(), HloOpcode::kGetTupleElement); + EXPECT_EQ(aux_indices[0]->tuple_index(), 1); +} + +TEST_F(WhileLoopAnalysisTest, NoAIV) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val1 = f32[2] get-tuple-element(p_body), index=0 + val2 = s32[] get-tuple-element(p_body), index=1 + add = s32[] add(val2, val2) + const.1 = s32[] constant(42) + mul = s32[] multiply(add, const.1) + div = s32[] divide(mul, add) + ROOT root = (f32[2], s32[]) tuple(val1, div) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] compare(gte, const), direction=EQ + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + std::vector aux_indices = + GetAuxiliaryLoopInductionVars(while_op); + EXPECT_EQ(aux_indices.size(), 0); +} + +TEST_F(WhileLoopAnalysisTest, AIVNoChain) { + const char* const kHloModule = R"( + HloModule ModuleWithWhile + + body { + p_body = (f32[2], s32[]) parameter(0) + val1 = f32[2] get-tuple-element(p_body), index=0 + val2 = s32[] get-tuple-element(p_body), index=1 + const = s32[] constant(42) + add = s32[] add(val2, const) + ROOT root = (f32[2], s32[]) tuple(val1, add) + } + + condition { + p_cond = (f32[2], s32[]) parameter(0) + gte = s32[] get-tuple-element(p_cond), index=1 + const = s32[] constant(42) + ROOT result = pred[] compare(gte, const), direction=EQ + } + + ENTRY entry { + param.0 = f32[2] parameter(0) + param.1 = s32[] parameter(1) + while_init = (f32[2], s32[]) tuple(param.0, param.1) + ROOT while = (f32[2], s32[]) while(while_init), condition=condition, body=body + })"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kHloModule)); + + HloInstruction* while_op = module->entry_computation()->root_instruction(); + std::vector aux_indices = + GetAuxiliaryLoopInductionVars(while_op); + EXPECT_EQ(aux_indices.size(), 1); + EXPECT_EQ(aux_indices[0]->opcode(), HloOpcode::kGetTupleElement); + EXPECT_EQ(aux_indices[0]->tuple_index(), 1); +} + } // namespace } // namespace xla