[XLA] Flatten tuples within while loops.
This is useful because it reduces the number of kTuple ops within the loop,
which can be expensive. (Normally these would be removed by tuple simplifier,
but that doesn't operate across computation boundaries.) But it's also useful
because it unlocks further optimization opportunities; for example, it makes
it easier to remove dead loop parameters.

PiperOrigin-RevId: 220846222
parent 8efd178f73
commit 6de4957aba
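For illustration, here is roughly what the pass does to a loop-carried shape.
This sketch uses real ShapeUtil helpers, but the variable names and concrete
shapes are hypothetical, not taken from the change:

// Hypothetical example of the shape rewrite: a loop carrying the nested tuple
// ((s32[1]), (s32[1], s32[1])) is rewritten to carry the flat tuple
// (s32[1], s32[1], s32[1]); the nested value is rebuilt outside the loop from
// get-tuple-element ops.
Shape elem = ShapeUtil::MakeShape(S32, {1});
Shape nested = ShapeUtil::MakeTupleShape(
    {ShapeUtil::MakeTupleShape({elem}),
     ShapeUtil::MakeTupleShape({elem, elem})});
Shape flat = ShapeUtil::MakeTupleShape({elem, elem, elem});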
@@ -1721,9 +1721,9 @@ cc_library(
        ":call_inliner",
        ":hlo",
        ":hlo_pass",
        ":hlo_query",
        ":while_loop_analysis",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
@@ -1735,6 +1735,8 @@ tf_cc_test(
    name = "while_loop_simplifier_test",
    srcs = ["while_loop_simplifier_test.cc"],
    deps = [
        ":hlo",
        ":hlo_dce",
        ":hlo_matchers",
        ":while_loop_simplifier",
        "//tensorflow/compiler/xla:test",
@@ -104,5 +104,20 @@ bool IsScalarConstant(const HloInstruction* instruction) {
  return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape());
}

bool ContainsInstrWithOpcode(const HloComputation* comp,
                             const absl::flat_hash_set<HloOpcode>& opcodes) {
  for (const auto* instr : comp->instructions()) {
    if (opcodes.count(instr->opcode())) {
      return true;
    }
    for (const HloComputation* subcomp : instr->called_computations()) {
      if (ContainsInstrWithOpcode(subcomp, opcodes)) {
        return true;
      }
    }
  }
  return false;
}

}  // namespace hlo_query
}  // namespace xla
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_

#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"

namespace xla {
@@ -41,6 +43,12 @@ bool AllOperandsAreConstants(const HloInstruction& instruction);
// Returns whether the instruction is a scalar constant.
bool IsScalarConstant(const HloInstruction* instruction);

// Determines whether the given computation contains an instruction with one of
// the given opcodes.  Checks both comp's instructions and the instructions of
// any computations nested within it.
bool ContainsInstrWithOpcode(const HloComputation* comp,
                             const absl::flat_hash_set<HloOpcode>& opcodes);

// Returns an operand of an instruction with the given opcode. If there are
// multiple matching operands, then the first matching operand is returned. If
// there are no matching operands then nullptr is returned.
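As a usage sketch (hedged: `comp` here is a hypothetical HloComputation*, not
a name from this change), a caller can test for several opcodes at once, just
as the simplifier below does for send/recv and kDomain:

// Does `comp`, or any computation it calls, contain control flow?
bool has_control_flow = hlo_query::ContainsInstrWithOpcode(
    comp, {HloOpcode::kWhile, HloOpcode::kConditional});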
@@ -20,40 +20,14 @@ limitations under the License.
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"

namespace xla {

using absl::optional;

// Determines whether the given instruction is a send/recv node, or has a
// subcomputation which contains a send/recv node.
static bool IsOrContainsSendOrRecv(const HloInstruction* instr);

// Determines whether the given computation contains a send or recv node.
static bool ContainsSendOrRecv(const HloComputation* comp) {
  for (const auto* instr : comp->instructions()) {
    if (IsOrContainsSendOrRecv(instr)) {
      return true;
    }
  }
  return false;
}

static bool IsOrContainsSendOrRecv(const HloInstruction* instr) {
  if (instr->opcode() == HloOpcode::kSend ||
      instr->opcode() == HloOpcode::kSendDone ||
      instr->opcode() == HloOpcode::kRecv ||
      instr->opcode() == HloOpcode::kRecvDone) {
    return true;
  }
  for (const auto& subcomp : instr->called_computations()) {
    if (ContainsSendOrRecv(subcomp)) {
      return true;
    }
  }
  return false;
}
using hlo_query::ContainsInstrWithOpcode;

// Tries to remove elements in a while loop's tuple that aren't used within the
// loop.
@@ -457,6 +431,180 @@ static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
  return changed_cond || changed_body;
}

// Converts a flat list of instructions into a tuple of the desired shape.  For
// example, given a tuple shape ((x, x), x) and instructions {A, B, C}, returns
// a tuple of value ((A, B), C).
//
// desired_shape must be a tuple.  (This precondition allows us to return a
// unique_ptr rather than a raw ptr.)
static std::unique_ptr<HloInstruction> UnflattenTupleInstr(
    absl::Span<HloInstruction*> instrs, const Shape& desired_shape,
    std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
  CHECK(ShapeUtil::IsTuple(desired_shape))
      << ShapeUtil::HumanString(desired_shape);

  // For each child shape in `desired_shape`, slice out the correct number of
  // `instrs` and call UnflattenTupleInstr recursively.  At each step we remove
  // elements from `instrs` so that it only contains instructions we have not
  // yet processed.
  std::vector<HloInstruction*> elems;
  for (int64 i = 0; i < desired_shape.tuple_shapes_size(); ++i) {
    const Shape& subshape = desired_shape.tuple_shapes(i);
    if (!ShapeUtil::IsTuple(subshape)) {
      elems.push_back(instrs[0]);
      instrs.remove_prefix(1);
      continue;
    }

    // Count the number of leaf nodes underneath desired_shape[i].
    int64 num_leaves = 0;
    ShapeUtil::ForEachSubshape(
        subshape, [&](const Shape& s, const ShapeIndex& /*index*/) {
          if (!ShapeUtil::IsTuple(s)) {
            ++num_leaves;
          }
        });

    std::unique_ptr<HloInstruction> subinstr =
        UnflattenTupleInstr(instrs.subspan(0, num_leaves),
                            desired_shape.tuple_shapes(i), new_instrs);
    elems.push_back(subinstr.get());
    new_instrs->push_back(std::move(subinstr));
    instrs.remove_prefix(num_leaves);
  }
  return HloInstruction::CreateTuple(elems);
}

// Builds a vector whose elements are the values in the flattened tuple for
// `instr`.  For example, if `instr` is a tuple of form ((A, B), C), returns the
// vector {A, B, C} (or kGetTupleElement ops which point to A, B, and C).
static std::vector<HloInstruction*> GetFlatTupleElems(
    HloInstruction* instr,
    std::vector<std::unique_ptr<HloInstruction>>* new_instrs) {
  const auto& shape = instr->shape();
  if (!ShapeUtil::IsTuple(shape)) {
    return {instr};
  }
  std::vector<HloInstruction*> elems;
  for (int64 i = 0; i < shape.tuple_shapes_size(); ++i) {
    const Shape& subshape = shape.tuple_shapes(i);
    new_instrs->push_back(
        HloInstruction::CreateGetTupleElement(subshape, instr, i));
    auto* gte = new_instrs->back().get();
    auto flattened_subshape = GetFlatTupleElems(gte, new_instrs);
    elems.insert(elems.end(), flattened_subshape.begin(),
                 flattened_subshape.end());
  }
  return elems;
}

static StatusOr<bool> TryFlattenNestedTuples(HloInstruction* while_op) {
  HloModule* module = while_op->GetModule();
  HloComputation* computation = while_op->parent();
  auto* while_init = while_op->mutable_operand(0);
  auto* while_body = while_op->while_body();
  auto* while_cond = while_op->while_condition();
  auto* while_body_root = while_body->root_instruction();
  if (while_init->opcode() != HloOpcode::kTuple ||
      while_body_root->opcode() != HloOpcode::kTuple) {
    return false;
  }

  TF_RET_CHECK(while_cond->num_parameters() == 1);
  TF_RET_CHECK(while_body->num_parameters() == 1);
  TF_RET_CHECK(
      ShapeUtil::Compatible(while_init->shape(), while_body_root->shape()));
  Shape while_shape = while_init->shape();
  if (!ShapeUtil::IsNestedTuple(while_shape)) {
    return false;
  }

  // Cowardly refuse to perform this optimization in the presence of kDomain
  // instructions, which may reference other instructions in the loop and
  // therefore make this complicated.
  if (ContainsInstrWithOpcode(while_body, {HloOpcode::kDomain}) ||
      ContainsInstrWithOpcode(while_cond, {HloOpcode::kDomain})) {
    return false;
  }

  std::vector<Shape> flattened_shape_elems;
  ShapeUtil::ForEachSubshape(while_shape,
                             [&](const Shape& s, const ShapeIndex& /*index*/) {
                               if (!ShapeUtil::IsTuple(s)) {
                                 flattened_shape_elems.push_back(s);
                               }
                             });
  Shape flattened_shape = ShapeUtil::MakeTupleShape(flattened_shape_elems);

  // `new_instrs` holds instructions created outside of a computation for
  // cloning.  Elements added here just need to live until the end of the
  // relevant CloneWithReplacement call.
  std::vector<std::unique_ptr<HloInstruction>> new_instrs;
  auto add_new_instr = [&](std::unique_ptr<HloInstruction> instr) {
    new_instrs.push_back(std::move(instr));
    return new_instrs.back().get();
  };

  auto nested = [&](HloInstruction* instr) {
    std::vector<HloInstruction*> gtes;
    const Shape& flat_shape = instr->shape();
    for (int64 i = 0; i < flat_shape.tuple_shapes_size(); ++i) {
      gtes.push_back(add_new_instr(HloInstruction::CreateGetTupleElement(
          flat_shape.tuple_shapes(i), instr, i)));
    }
    auto nested_instr =
        UnflattenTupleInstr(absl::MakeSpan(gtes), while_shape, &new_instrs);
    CHECK(ShapeUtil::Compatible(nested_instr->shape(), while_shape))
        << ShapeUtil::HumanString(nested_instr->shape()) << " vs "
        << ShapeUtil::HumanString(while_shape);
    return nested_instr;
  };

  auto flattened = [&](HloInstruction* instr) {
    return HloInstruction::CreateTuple(GetFlatTupleElems(instr, &new_instrs));
  };

  // Create a new while-condition computation, where parameter 0 has flat shape
  // but all uses of it go through the nested shape.
  std::unique_ptr<HloComputation> new_while_cond =
      while_cond->CloneWithReplacementPairs({
          while_cond->parameter_instruction(0),
          nested(add_new_instr(HloInstruction::CreateParameter(
              0, flattened_shape,
              while_cond->parameter_instruction(0)->name()))),
      });

  // Create a new while-body computation, where parameter 0 has a flat shape
  // and all uses of it go through the nested shape, and where the root has a
  // flat shape constructed from the old nested root.
  std::unique_ptr<HloComputation> new_while_body =
      while_body->CloneWithReplacementPairs(
          {
              while_body->parameter_instruction(0),
              nested(add_new_instr(HloInstruction::CreateParameter(
                  0, flattened_shape,
                  while_body->parameter_instruction(0)->name()))),
          },
          {
              while_body->root_instruction(),
              flattened(add_new_instr(while_body->root_instruction()->Clone())),
          });

  // Create the final while loop, and add any new instructions created to
  // `computation`.
  new_instrs.clear();
  TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
      while_op, nested(computation->AddInstruction(HloInstruction::CreateWhile(
                    flattened_shape,
                    module->AddEmbeddedComputation(std::move(new_while_cond)),
                    module->AddEmbeddedComputation(std::move(new_while_body)),
                    computation->AddInstruction(flattened(while_init)))))));
  for (auto& instr : new_instrs) {
    computation->AddInstruction(std::move(instr));
  }
  return true;
}

StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
  XLA_VLOG_LINES(3,
                 "WhileLoopSimplifier::Run(), before:\n" + module->ToString());
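The flatten/unflatten pair above obeys a simple round-trip invariant: leaves
come out in depth-first order, and rebuilding consumes exactly one leaf per
array subshape. A self-contained sketch of that invariant in plain C++ (a toy
Node standing in for HLO shapes and instructions; nothing here is XLA API):

#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for a (possibly nested) tuple: a leaf has a name, a tuple has
// children.  Empty tuples are not modeled.
struct Node {
  std::string leaf;
  std::vector<Node> children;
};

// Depth-first leaf collection, mirroring GetFlatTupleElems.
void Flatten(const Node& n, std::vector<std::string>* out) {
  if (n.children.empty()) {
    out->push_back(n.leaf);
    return;
  }
  for (const Node& c : n.children) Flatten(c, out);
}

// Rebuild the nesting described by `shape`, consuming leaves left to right,
// mirroring UnflattenTupleInstr.
Node Unflatten(const std::vector<std::string>& leaves, size_t* pos,
               const Node& shape) {
  if (shape.children.empty()) return Node{leaves[(*pos)++], {}};
  Node out;
  for (const Node& c : shape.children) {
    out.children.push_back(Unflatten(leaves, pos, c));
  }
  return out;
}

int main() {
  // Shape ((x, x), x) with leaves A, B, C, as in the comment above.
  Node shape{"", {Node{"", {Node{"A", {}}, Node{"B", {}}}}, Node{"C", {}}}};
  std::vector<std::string> flat;
  Flatten(shape, &flat);
  assert((flat == std::vector<std::string>{"A", "B", "C"}));
  size_t pos = 0;
  Node rebuilt = Unflatten(flat, &pos, shape);
  assert(pos == 3 && rebuilt.children[0].children[1].leaf == "B");
  return 0;
}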
@@ -477,32 +625,46 @@ StatusOr<bool> WhileLoopSimplifier::Run(HloModule* module) {
  for (HloInstruction* while_op : while_ops) {
    // We can't remove while loops that contain send/recv nodes, because we rely
    // on the particular loop structure around the node matching on the send and
    // recv sides. Removing dead while params requires us to remove the loop
    // recv sides. Other while simplifications require us to remove the loop
    // and replace it with a new one, so we can't do that either.
    if (ContainsSendOrRecv(while_op->while_body()) ||
        ContainsSendOrRecv(while_op->while_condition())) {
    if (ContainsInstrWithOpcode(while_op->while_body(),
                                {HloOpcode::kSend, HloOpcode::kSendDone,
                                 HloOpcode::kRecv, HloOpcode::kRecvDone}) ||
        ContainsInstrWithOpcode(while_op->while_condition(),
                                {HloOpcode::kSend, HloOpcode::kSendDone,
                                 HloOpcode::kRecv, HloOpcode::kRecvDone})) {
      VLOG(2) << "Not attempting to simplify while loop because it contains a "
                 "send/recv node: "
              << while_op->ToShortString();
      continue;
    }

    StatusOr<bool> result = TryPropagateConstant(while_op);
    TF_RETURN_IF_ERROR(result.status());
    changed |= result.ValueOrDie();
    TF_ASSIGN_OR_RETURN(bool result, TryPropagateConstant(while_op));
    changed |= result;

    result = TryRemoveWhileLoop(while_op);
    TF_RETURN_IF_ERROR(result.status());
    if (result.ValueOrDie()) {
      changed = true;
      // Don't try to remove dead while params after successfully removing the
      // while loop -- that would result in use-after-free nastiness.
    TF_ASSIGN_OR_RETURN(result, TryRemoveWhileLoop(while_op));
    changed |= result;
    if (result) {
      // Don't continue simplifying after successfully removing the while loop
      // -- that would result in use-after-free nastiness.
      continue;
    }

    result = TryRemoveDeadWhileParams(while_op);
    TF_RETURN_IF_ERROR(result.status());
    changed |= result.ValueOrDie();
    TF_ASSIGN_OR_RETURN(result, TryFlattenNestedTuples(while_op));
    changed |= result;
    if (result) {
      // Successfully flattening nested tuples results in us cloning and
      // replacing the while loop, meaning that `while_op` is no longer valid.
      continue;
    }

    TF_ASSIGN_OR_RETURN(result, TryRemoveDeadWhileParams(while_op));
    changed |= result;
    if (result) {
      // Successfully removing dead while params results in us cloning and
      // replacing the while loop, meaning that `while_op` is no longer valid.
      continue;
    }
  }

  XLA_VLOG_LINES(3,
@@ -25,11 +25,22 @@ namespace xla {
// HLO pass that makes the following transformations on while loops:
//
//  - A while loop with static trip count of 0 is deleted.
//
//  - A while loop with static trip count of 1 is replaced by its body (sans
//    loop).
//
//  - Elements of a while loop's tuple that the loop doesn't use are removed
//    from the tuple.
//
//  - If the while loop's parameter is a nested tuple, it's flattened to a
//    single-level tuple.  This is good because it usually reduces the number
//    of kTuple instructions, but also because it unlocks additional
//    optimizations (e.g. removing unused loop parameters).
//
// Flattening nested while loop tuples adds a whole mess of likely unnecessary
// kGetTupleElement and kTuple operations to the graph.  We expect that tuple
// simplifier will be run afterwards.
//
class WhileLoopSimplifier : public HloModulePass {
 public:
  ~WhileLoopSimplifier() override {}
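Since the header comment says tuple simplifier is expected to run afterwards,
here is a minimal pipeline sketch, assuming the standard HloPassPipeline API;
`module` is a hypothetical HloModule*, and the exact pass placement is
illustrative, not prescribed by this commit:

HloPassPipeline pipeline("while-loop-simplification");
pipeline.AddPass<WhileLoopSimplifier>();
// Clean up the kGetTupleElement/kTuple churn that flattening introduces.
pipeline.AddPass<TupleSimplifier>();
pipeline.AddPass<HloDCE>();
TF_RETURN_IF_ERROR(pipeline.Run(module).status());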
@@ -17,6 +17,8 @@ limitations under the License.

#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
@@ -509,5 +511,62 @@ TEST_F(WhileLoopSimplifierTest, LoopWithArrayConstantNotSimplified) {
  EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
}

TEST_F(WhileLoopSimplifierTest, FlattenNestedTuple) {
  const string hlo_string = R"(
  HloModule Test
  Body {
    param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
    ta = (s32[1]) get-tuple-element(param), index=0
    a = s32[1] get-tuple-element(ta), index=0
    a.1 = s32[1] add(a, a)
    tbcd = (s32[2], s32[3], (s32[4])) get-tuple-element(param), index=1
    ROOT tuple = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd)
  }
  Cond {
    param = ((s32[1]), (s32[2], s32[3], (s32[4]))) parameter(0)
    ROOT cond = pred[] constant(true)
  }
  ENTRY Loop {
    a = s32[1] constant({0})
    b = s32[2] constant({0,1})
    c = s32[3] constant({0,1,2})
    d = s32[4] constant({0,1,2,3})
    ta = (s32[1]) tuple(a)
    td = (s32[4]) tuple(d)
    tbcd = (s32[2], s32[3], (s32[4])) tuple(b, c, td)
    init = ((s32[1]), (s32[2], s32[3], (s32[4]))) tuple(ta, tbcd)
    ROOT while = ((s32[1]), (s32[2], s32[3], (s32[4]))) while(init),
      condition=Cond, body=Body
  })";

  ParseAndVerifyModule(hlo_string);
  EXPECT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie());
  // DCE away the old loop so there's just one while loop in the module, making
  // it easy to find.
  EXPECT_TRUE(HloDCE().Run(&module()).ok());

  const auto& instrs = module().entry_computation()->instructions();
  HloInstruction* new_while =
      *absl::c_find_if(instrs, [](const HloInstruction* instr) {
        return instr->opcode() == HloOpcode::kWhile;
      });
  Shape flat_tuple =
      ShapeUtil::ParseShapeString("(s32[1], s32[2], s32[3], s32[4])")
          .ValueOrDie();
  SCOPED_TRACE(module().ToString());
  EXPECT_TRUE(ShapeUtil::Equal(new_while->shape(), flat_tuple));
  EXPECT_TRUE(ShapeUtil::Equal(
      new_while->while_body()->root_instruction()->shape(), flat_tuple));
  EXPECT_TRUE(ShapeUtil::Equal(
      new_while->while_body()->parameter_instruction(0)->shape(), flat_tuple));
  EXPECT_TRUE(ShapeUtil::Equal(
      new_while->while_condition()->parameter_instruction(0)->shape(),
      flat_tuple));
  EXPECT_TRUE(ShapeUtil::Equal(
      module().entry_computation()->root_instruction()->shape(),
      ShapeUtil::ParseShapeString("((s32[1]), (s32[2], s32[3], (s32[4])))")
          .ValueOrDie()));
}

}  // namespace
}  // namespace xla