[XLA] Add a RootInstructionSinker pass.

Memory space assignment requires ROOT instructions of while bodies to the the
latest in the schedule. This pass sinks the ROOTs to the end of the schedule. To
make sure dependencies are respected (e.g. the instructions after the ROOT may
depend on the ROOT instruction), this pass will insert either
tuple(gte(old_root), gte(old_root), ...) or bitcast(old_root) at the end of the
computation. Note that Hlo live range, hence copy insertion, always calculate
ROOTs' live ranges to be until the end of the computation, so it should be safe
to move the ROOTs to the end of the schedule.

PiperOrigin-RevId: 307902025
Change-Id: I2bdcc9cc660fcffbbe7611f76034c9924cf29c3d
This commit is contained in:
Berkin Ilbeyi 2020-04-22 14:37:44 -07:00 committed by TensorFlower Gardener
parent 9a5e675a70
commit f9d6dd6269
6 changed files with 319 additions and 0 deletions

View File

@ -4117,6 +4117,28 @@ tf_cc_test(
],
)
cc_library(
name = "root_instruction_sinker",
srcs = ["root_instruction_sinker.cc"],
hdrs = ["root_instruction_sinker.h"],
deps = [
":hlo",
":hlo_pass",
":tuple_util",
],
)
tf_cc_test(
name = "root_instruction_sinker_test",
srcs = ["root_instruction_sinker_test.cc"],
deps = [
":hlo_matchers",
":root_instruction_sinker",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library(
name = "while_util",
srcs = ["while_util.cc"],

View File

@ -672,6 +672,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
// interval (5-6) can be allocated separately and this buffer
// doesn't waste alternate memory space within the while loop body.
HloComputation* while_body = use.instruction->while_body();
// We require while body ROOTs to be the last in the schedule.
CHECK_EQ(
instruction_schedule.at(while_body->root_instruction()) + 1,
instruction_schedule.at(use.instruction))
<< "While body ROOTs need to be the last in the schedule! "
"Please run RootInstructionSinker.";
// Replace the use time with the parameter time so that we can
// decide on alternate memory allocations within the while loop body
// when we look at uses within the while loop body.

View File

@ -0,0 +1,73 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/root_instruction_sinker.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
namespace xla {
namespace {
// Sinks the root of the given computation for tuple root types.
void SinkTupleRoot(HloComputation* computation) {
HloInstruction* root = computation->root_instruction();
CHECK(root->shape().IsTuple());
HloInstruction* new_root = TupleUtil::Duplicate(root);
// Add the new instructions to the schedule.
HloInstructionSequence& sequence =
computation->parent()->schedule().GetOrCreateSequence(computation);
for (HloInstruction* operand : new_root->operands()) {
sequence.push_back(operand);
}
sequence.push_back(new_root);
computation->set_root_instruction(new_root);
}
// Sinks the root of the given computation for not-tuple root types.
void SinkNontupleRoot(HloComputation* computation) {
HloInstruction* root = computation->root_instruction();
CHECK(!root->shape().IsTuple());
HloInstruction* new_root = computation->AddInstruction(
HloInstruction::CreateBitcast(root->shape(), root));
HloInstructionSequence& sequence =
computation->parent()->schedule().GetOrCreateSequence(computation);
sequence.push_back(new_root);
computation->set_root_instruction(new_root);
}
} // namespace
StatusOr<bool> RootInstructionSinker::Run(HloModule* module) {
TF_RET_CHECK(module->has_schedule());
bool modified = false;
for (HloComputation* computation : module->MakeNonfusionComputations()) {
HloInstructionSequence& sequence =
module->schedule().GetOrCreateSequence(computation);
if (computation->root_instruction() ==
sequence.instructions().at(sequence.size() - 1)) {
continue;
}
if (computation->root_instruction()->shape().IsTuple()) {
SinkTupleRoot(computation);
} else {
SinkNontupleRoot(computation);
}
modified = true;
}
return modified;
}
} // namespace xla

View File

@ -0,0 +1,41 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// Given a scheduled HLO module, this pass sinks the ROOT of the instruction to
// the bottom of the non-fusion computations. To avoid dependency violations of
// moving the ROOT instruction, it creates a new ROOT instruction that looks
// like the following:
// - For tuple ROOT type:
// new_root = tuple(gte(old_root), gte(old_root), ...)
// - For non-tuple ROOT type:
// new_root = bitcast(old_root)
class RootInstructionSinker : public HloModulePass {
public:
~RootInstructionSinker() override = default;
absl::string_view name() const override { return "root-instruction-sinker"; }
StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_

View File

@ -0,0 +1,170 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/root_instruction_sinker.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
namespace xla {
namespace {
namespace op = xla::testing::opcode_matchers;
using RootInstructionSinkerTest = HloTestBase;
TEST_F(RootInstructionSinkerTest, TupleNoChange) {
// ROOTS are already sunk, no change performed to the module.
absl::string_view hlo_string = R"(
HloModule While, is_scheduled=true
While.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
}
While.condition {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(100)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY While {
constant.3 = s32[] constant(42)
constant.4 = s32[3]{0} constant({0, 1, 2})
tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
While.condition, body=While.body
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
auto while_body =
module->entry_computation()->root_instruction()->while_body();
int num_body_instructions = while_body->instruction_count();
RootInstructionSinker sinker;
EXPECT_FALSE(sinker.Run(module.get()).ValueOrDie());
EXPECT_EQ(module->entry_computation()
->root_instruction()
->while_body()
->instruction_count(),
num_body_instructions);
}
TEST_F(RootInstructionSinkerTest, Tuple) {
// Sink tuple return type.
absl::string_view hlo_string = R"(
HloModule While, is_scheduled=true
While.body {
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
constant.1 = s32[] constant(1)
add = s32[] add(get-tuple-element.1, constant.1)
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
after-all = token[] after-all()
send = (s32[3]{0}, u32[], token[]) send(multiply, after-all), channel_id=1
send-done = token[] send-done(send), channel_id=1
}
While.condition {
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
constant.2 = s32[] constant(100)
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
}
ENTRY While {
constant.3 = s32[] constant(42)
constant.4 = s32[3]{0} constant({0, 1, 2})
tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
While.condition, body=While.body
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
RootInstructionSinker sinker;
EXPECT_TRUE(sinker.Run(module.get()).ValueOrDie());
auto while_body =
module->entry_computation()->root_instruction()->while_body();
const auto& sequence = module->schedule().sequence(while_body);
EXPECT_EQ(sequence.instructions().at(sequence.size() - 1),
while_body->root_instruction());
EXPECT_THAT(while_body->root_instruction(),
op::Tuple(op::GetTupleElement(op::Tuple()),
op::GetTupleElement(op::Tuple())));
}
TEST_F(RootInstructionSinkerTest, NontupleNoChange) {
// ROOTS are already sunk, no change performed to the module.
absl::string_view hlo_string = R"(
HloModule Call, is_scheduled=true
Call {
param = s32[3]{0} parameter(0)
ROOT multiply = s32[3]{0} multiply(param, param)
}
ENTRY While {
constant.4 = s32[3]{0} constant({0, 1, 2})
ROOT call = s32[3]{0} call(constant.4), to_apply=Call
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
auto called_computation =
module->entry_computation()->root_instruction()->called_computations()[0];
int num_instructions = called_computation->instruction_count();
RootInstructionSinker sinker;
EXPECT_FALSE(sinker.Run(module.get()).ValueOrDie());
EXPECT_EQ(module->entry_computation()
->root_instruction()
->called_computations()[0]
->instruction_count(),
num_instructions);
}
TEST_F(RootInstructionSinkerTest, Nontuple) {
// Sink a non-tuple return type.
absl::string_view hlo_string = R"(
HloModule Call, is_scheduled=true
Call {
param = s32[3]{0} parameter(0)
ROOT multiply = s32[3]{0} multiply(param, param)
after-all = token[] after-all()
send = (s32[3]{0}, u32[], token[]) send(multiply, after-all), channel_id=1
send-done = token[] send-done(send), channel_id=1
}
ENTRY While {
constant.4 = s32[3]{0} constant({0, 1, 2})
ROOT call = s32[3]{0} call(constant.4), to_apply=Call
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
RootInstructionSinker sinker;
EXPECT_TRUE(sinker.Run(module.get()).ValueOrDie());
auto called_computation =
module->entry_computation()->root_instruction()->called_computations()[0];
const auto& sequence = module->schedule().sequence(called_computation);
EXPECT_EQ(sequence.instructions().at(sequence.size() - 1),
called_computation->root_instruction());
EXPECT_THAT(called_computation->root_instruction(),
op::Bitcast(op::Multiply()));
}
} // namespace
} // namespace xla

View File

@ -39,6 +39,13 @@ class TupleUtil {
static HloInstruction* AppendSuffix(
HloInstruction* input_tuple,
absl::Span<HloInstruction* const> trailing_values);
// Generates HLO instructions that duplicates the tuple by inserting
// get-tuple-elements and a new tuple instruction. Returns the root of the
// graph of instructions generated.
static HloInstruction* Duplicate(HloInstruction* input_tuple) {
return ExtractPrefix(input_tuple, input_tuple->shape().tuple_shapes_size());
}
};
} // namespace xla