[XLA] Add a RootInstructionSinker pass.
Memory space assignment requires ROOT instructions of while bodies to the the latest in the schedule. This pass sinks the ROOTs to the end of the schedule. To make sure dependencies are respected (e.g. the instructions after the ROOT may depend on the ROOT instruction), this pass will insert either tuple(gte(old_root), gte(old_root), ...) or bitcast(old_root) at the end of the computation. Note that Hlo live range, hence copy insertion, always calculate ROOTs' live ranges to be until the end of the computation, so it should be safe to move the ROOTs to the end of the schedule. PiperOrigin-RevId: 307902025 Change-Id: I2bdcc9cc660fcffbbe7611f76034c9924cf29c3d
This commit is contained in:
parent
9a5e675a70
commit
f9d6dd6269
@ -4117,6 +4117,28 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "root_instruction_sinker",
|
||||
srcs = ["root_instruction_sinker.cc"],
|
||||
hdrs = ["root_instruction_sinker.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_pass",
|
||||
":tuple_util",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "root_instruction_sinker_test",
|
||||
srcs = ["root_instruction_sinker_test.cc"],
|
||||
deps = [
|
||||
":hlo_matchers",
|
||||
":root_instruction_sinker",
|
||||
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "while_util",
|
||||
srcs = ["while_util.cc"],
|
||||
|
||||
@ -672,6 +672,12 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
// interval (5-6) can be allocated separately and this buffer
|
||||
// doesn't waste alternate memory space within the while loop body.
|
||||
HloComputation* while_body = use.instruction->while_body();
|
||||
// We require while body ROOTs to be the last in the schedule.
|
||||
CHECK_EQ(
|
||||
instruction_schedule.at(while_body->root_instruction()) + 1,
|
||||
instruction_schedule.at(use.instruction))
|
||||
<< "While body ROOTs need to be the last in the schedule! "
|
||||
"Please run RootInstructionSinker.";
|
||||
// Replace the use time with the parameter time so that we can
|
||||
// decide on alternate memory allocations within the while loop body
|
||||
// when we look at uses within the while loop body.
|
||||
|
||||
73
tensorflow/compiler/xla/service/root_instruction_sinker.cc
Normal file
73
tensorflow/compiler/xla/service/root_instruction_sinker.cc
Normal file
@ -0,0 +1,73 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/root_instruction_sinker.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/tuple_util.h"
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
// Sinks the root of the given computation for tuple root types.
|
||||
void SinkTupleRoot(HloComputation* computation) {
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
CHECK(root->shape().IsTuple());
|
||||
HloInstruction* new_root = TupleUtil::Duplicate(root);
|
||||
// Add the new instructions to the schedule.
|
||||
HloInstructionSequence& sequence =
|
||||
computation->parent()->schedule().GetOrCreateSequence(computation);
|
||||
for (HloInstruction* operand : new_root->operands()) {
|
||||
sequence.push_back(operand);
|
||||
}
|
||||
sequence.push_back(new_root);
|
||||
computation->set_root_instruction(new_root);
|
||||
}
|
||||
|
||||
// Sinks the root of the given computation for not-tuple root types.
|
||||
void SinkNontupleRoot(HloComputation* computation) {
|
||||
HloInstruction* root = computation->root_instruction();
|
||||
CHECK(!root->shape().IsTuple());
|
||||
HloInstruction* new_root = computation->AddInstruction(
|
||||
HloInstruction::CreateBitcast(root->shape(), root));
|
||||
HloInstructionSequence& sequence =
|
||||
computation->parent()->schedule().GetOrCreateSequence(computation);
|
||||
sequence.push_back(new_root);
|
||||
computation->set_root_instruction(new_root);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<bool> RootInstructionSinker::Run(HloModule* module) {
|
||||
TF_RET_CHECK(module->has_schedule());
|
||||
|
||||
bool modified = false;
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
HloInstructionSequence& sequence =
|
||||
module->schedule().GetOrCreateSequence(computation);
|
||||
if (computation->root_instruction() ==
|
||||
sequence.instructions().at(sequence.size() - 1)) {
|
||||
continue;
|
||||
}
|
||||
if (computation->root_instruction()->shape().IsTuple()) {
|
||||
SinkTupleRoot(computation);
|
||||
} else {
|
||||
SinkNontupleRoot(computation);
|
||||
}
|
||||
modified = true;
|
||||
}
|
||||
return modified;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
41
tensorflow/compiler/xla/service/root_instruction_sinker.h
Normal file
41
tensorflow/compiler/xla/service/root_instruction_sinker.h
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Given a scheduled HLO module, this pass sinks the ROOT of the instruction to
|
||||
// the bottom of the non-fusion computations. To avoid dependency violations of
|
||||
// moving the ROOT instruction, it creates a new ROOT instruction that looks
|
||||
// like the following:
|
||||
// - For tuple ROOT type:
|
||||
// new_root = tuple(gte(old_root), gte(old_root), ...)
|
||||
// - For non-tuple ROOT type:
|
||||
// new_root = bitcast(old_root)
|
||||
class RootInstructionSinker : public HloModulePass {
|
||||
public:
|
||||
~RootInstructionSinker() override = default;
|
||||
absl::string_view name() const override { return "root-instruction-sinker"; }
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ROOT_INSTRUCTION_SINKER_H_
|
||||
170
tensorflow/compiler/xla/service/root_instruction_sinker_test.cc
Normal file
170
tensorflow/compiler/xla/service/root_instruction_sinker_test.cc
Normal file
@ -0,0 +1,170 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/root_instruction_sinker.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace op = xla::testing::opcode_matchers;
|
||||
|
||||
using RootInstructionSinkerTest = HloTestBase;
|
||||
|
||||
TEST_F(RootInstructionSinkerTest, TupleNoChange) {
|
||||
// ROOTS are already sunk, no change performed to the module.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule While, is_scheduled=true
|
||||
While.body {
|
||||
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
|
||||
constant.1 = s32[] constant(1)
|
||||
add = s32[] add(get-tuple-element.1, constant.1)
|
||||
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
|
||||
multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
|
||||
ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
|
||||
}
|
||||
While.condition {
|
||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.2 = s32[] constant(100)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
ENTRY While {
|
||||
constant.3 = s32[] constant(42)
|
||||
constant.4 = s32[3]{0} constant({0, 1, 2})
|
||||
tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
|
||||
ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
|
||||
While.condition, body=While.body
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
auto while_body =
|
||||
module->entry_computation()->root_instruction()->while_body();
|
||||
int num_body_instructions = while_body->instruction_count();
|
||||
RootInstructionSinker sinker;
|
||||
EXPECT_FALSE(sinker.Run(module.get()).ValueOrDie());
|
||||
EXPECT_EQ(module->entry_computation()
|
||||
->root_instruction()
|
||||
->while_body()
|
||||
->instruction_count(),
|
||||
num_body_instructions);
|
||||
}
|
||||
|
||||
TEST_F(RootInstructionSinkerTest, Tuple) {
|
||||
// Sink tuple return type.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule While, is_scheduled=true
|
||||
While.body {
|
||||
loop_var.1 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.1 = s32[] get-tuple-element(loop_var.1), index=0
|
||||
constant.1 = s32[] constant(1)
|
||||
add = s32[] add(get-tuple-element.1, constant.1)
|
||||
get-tuple-element.2 = s32[3]{0} get-tuple-element(loop_var.1), index=1
|
||||
multiply = s32[3]{0} multiply(get-tuple-element.2, get-tuple-element.2)
|
||||
ROOT tuple = (s32[], s32[3]{0}) tuple(add, multiply)
|
||||
after-all = token[] after-all()
|
||||
send = (s32[3]{0}, u32[], token[]) send(multiply, after-all), channel_id=1
|
||||
send-done = token[] send-done(send), channel_id=1
|
||||
}
|
||||
While.condition {
|
||||
loop_var.2 = (s32[], s32[3]{0}) parameter(0)
|
||||
get-tuple-element.3 = s32[] get-tuple-element(loop_var.2), index=0
|
||||
constant.2 = s32[] constant(100)
|
||||
ROOT less-than = pred[] compare(get-tuple-element.3, constant.2), direction=LT
|
||||
}
|
||||
ENTRY While {
|
||||
constant.3 = s32[] constant(42)
|
||||
constant.4 = s32[3]{0} constant({0, 1, 2})
|
||||
tuple.1 = (s32[], s32[3]{0}) tuple(constant.3, constant.4)
|
||||
ROOT while = (s32[], s32[3]{0}) while(tuple.1), condition=
|
||||
While.condition, body=While.body
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
RootInstructionSinker sinker;
|
||||
EXPECT_TRUE(sinker.Run(module.get()).ValueOrDie());
|
||||
auto while_body =
|
||||
module->entry_computation()->root_instruction()->while_body();
|
||||
const auto& sequence = module->schedule().sequence(while_body);
|
||||
EXPECT_EQ(sequence.instructions().at(sequence.size() - 1),
|
||||
while_body->root_instruction());
|
||||
EXPECT_THAT(while_body->root_instruction(),
|
||||
op::Tuple(op::GetTupleElement(op::Tuple()),
|
||||
op::GetTupleElement(op::Tuple())));
|
||||
}
|
||||
|
||||
TEST_F(RootInstructionSinkerTest, NontupleNoChange) {
|
||||
// ROOTS are already sunk, no change performed to the module.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Call, is_scheduled=true
|
||||
Call {
|
||||
param = s32[3]{0} parameter(0)
|
||||
ROOT multiply = s32[3]{0} multiply(param, param)
|
||||
}
|
||||
ENTRY While {
|
||||
constant.4 = s32[3]{0} constant({0, 1, 2})
|
||||
ROOT call = s32[3]{0} call(constant.4), to_apply=Call
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
auto called_computation =
|
||||
module->entry_computation()->root_instruction()->called_computations()[0];
|
||||
int num_instructions = called_computation->instruction_count();
|
||||
RootInstructionSinker sinker;
|
||||
EXPECT_FALSE(sinker.Run(module.get()).ValueOrDie());
|
||||
EXPECT_EQ(module->entry_computation()
|
||||
->root_instruction()
|
||||
->called_computations()[0]
|
||||
->instruction_count(),
|
||||
num_instructions);
|
||||
}
|
||||
|
||||
TEST_F(RootInstructionSinkerTest, Nontuple) {
|
||||
// Sink a non-tuple return type.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Call, is_scheduled=true
|
||||
Call {
|
||||
param = s32[3]{0} parameter(0)
|
||||
ROOT multiply = s32[3]{0} multiply(param, param)
|
||||
after-all = token[] after-all()
|
||||
send = (s32[3]{0}, u32[], token[]) send(multiply, after-all), channel_id=1
|
||||
send-done = token[] send-done(send), channel_id=1
|
||||
}
|
||||
ENTRY While {
|
||||
constant.4 = s32[3]{0} constant({0, 1, 2})
|
||||
ROOT call = s32[3]{0} call(constant.4), to_apply=Call
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
RootInstructionSinker sinker;
|
||||
EXPECT_TRUE(sinker.Run(module.get()).ValueOrDie());
|
||||
auto called_computation =
|
||||
module->entry_computation()->root_instruction()->called_computations()[0];
|
||||
const auto& sequence = module->schedule().sequence(called_computation);
|
||||
EXPECT_EQ(sequence.instructions().at(sequence.size() - 1),
|
||||
called_computation->root_instruction());
|
||||
EXPECT_THAT(called_computation->root_instruction(),
|
||||
op::Bitcast(op::Multiply()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
@ -39,6 +39,13 @@ class TupleUtil {
|
||||
static HloInstruction* AppendSuffix(
|
||||
HloInstruction* input_tuple,
|
||||
absl::Span<HloInstruction* const> trailing_values);
|
||||
|
||||
// Generates HLO instructions that duplicates the tuple by inserting
|
||||
// get-tuple-elements and a new tuple instruction. Returns the root of the
|
||||
// graph of instructions generated.
|
||||
static HloInstruction* Duplicate(HloInstruction* input_tuple) {
|
||||
return ExtractPrefix(input_tuple, input_tuple->shape().tuple_shapes_size());
|
||||
}
|
||||
};
|
||||
} // namespace xla
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user