341 lines
12 KiB
C++
341 lines
12 KiB
C++
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
|
|
|
|
#include <memory>
|
|
#include <string>
|
|
|
|
#include "absl/algorithm/container.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
|
#include "tensorflow/compiler/xla/shape_util.h"
|
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
|
#include "tensorflow/compiler/xla/types.h"
|
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
|
|
namespace xla {
|
|
namespace {
|
|
|
|
class HloScheduleTest : public HloTestBase {};
|
|
|
|
TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) {
|
|
// Updating the schedule of an unchanged HLO module should not affect the
|
|
// schedule at all.
|
|
const string module_str = R"(
|
|
HloModule UpdateScheduleUnchanged
|
|
|
|
ENTRY main {
|
|
a = f32[] parameter(0)
|
|
b = f32[] parameter(1)
|
|
c = f32[] constant(42.0)
|
|
sum = f32[] add(a, b)
|
|
neg = f32[] negate(c)
|
|
ROOT root = f32[] multiply(sum, neg)
|
|
}
|
|
)";
|
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
ParseAndReturnVerifiedModule(module_str));
|
|
TF_ASSERT_OK_AND_ASSIGN(
|
|
HloSchedule schedule,
|
|
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
|
return ShapeUtil::ByteSizeOf(buffer.shape());
|
|
}));
|
|
const auto& entry_schedule =
|
|
schedule.sequence(module->entry_computation()).instructions();
|
|
|
|
EXPECT_EQ(entry_schedule.size(), 6);
|
|
|
|
TF_ASSERT_OK(schedule.Update());
|
|
TF_ASSERT_OK(schedule.Verify());
|
|
|
|
EXPECT_EQ(entry_schedule,
|
|
schedule.sequence(module->entry_computation()).instructions());
|
|
}
|
|
|
|
TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) {
|
|
// Add some additional instructions to a module and verify the schedule can be
|
|
// updated.
|
|
const string module_str = R"(
|
|
HloModule UpdateScheduleWithNewInstructions
|
|
|
|
ENTRY main {
|
|
a = f32[] parameter(0)
|
|
b = f32[] parameter(1)
|
|
c = f32[] constant(42.0)
|
|
sum = f32[] add(a, b)
|
|
neg = f32[] negate(c)
|
|
ROOT root = f32[] multiply(sum, neg)
|
|
}
|
|
)";
|
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
ParseAndReturnVerifiedModule(module_str));
|
|
TF_ASSERT_OK_AND_ASSIGN(
|
|
HloSchedule schedule,
|
|
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
|
return ShapeUtil::ByteSizeOf(buffer.shape());
|
|
}));
|
|
|
|
HloComputation* entry = module->entry_computation();
|
|
const Shape shape = entry->root_instruction()->shape();
|
|
HloInstruction* constant = entry->AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
|
|
HloInstruction* sub = entry->AddInstruction(HloInstruction::CreateBinary(
|
|
shape, HloOpcode::kSubtract, constant, entry->root_instruction()));
|
|
entry->set_root_instruction(sub);
|
|
|
|
auto in_schedule = [&](const HloInstruction* hlo) {
|
|
return absl::c_linear_search(schedule.sequence(entry).instructions(), hlo);
|
|
};
|
|
|
|
EXPECT_EQ(schedule.sequence(entry).size(), 6);
|
|
EXPECT_FALSE(in_schedule(constant));
|
|
EXPECT_FALSE(in_schedule(sub));
|
|
|
|
ASSERT_IS_NOT_OK(schedule.Verify());
|
|
TF_ASSERT_OK(schedule.Update());
|
|
TF_ASSERT_OK(schedule.Verify());
|
|
|
|
EXPECT_EQ(schedule.sequence(entry).size(), 8);
|
|
EXPECT_TRUE(in_schedule(constant));
|
|
EXPECT_TRUE(in_schedule(sub));
|
|
}
|
|
|
|
TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) {
|
|
// Add and delete some instructions from a module and verify that the schedule
|
|
// can be updated successfully.
|
|
const string module_str = R"(
|
|
HloModule UpdateScheduleWithAddedAndDeletedInstruction
|
|
|
|
ENTRY main {
|
|
a = f32[] parameter(0)
|
|
b = f32[] parameter(1)
|
|
c = f32[] constant(42.0)
|
|
sum = f32[] add(a, b)
|
|
neg = f32[] negate(c)
|
|
ROOT root = f32[] multiply(sum, neg)
|
|
}
|
|
)";
|
|
|
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
ParseAndReturnVerifiedModule(module_str));
|
|
TF_ASSERT_OK_AND_ASSIGN(
|
|
HloSchedule schedule,
|
|
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
|
return ShapeUtil::ByteSizeOf(buffer.shape());
|
|
}));
|
|
|
|
// Set the entry root to some expression containing just a parameter and a
|
|
// constant.
|
|
HloComputation* entry = module->entry_computation();
|
|
HloInstruction* constant = entry->AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
|
|
HloInstruction* new_root = entry->AddInstruction(
|
|
HloInstruction::CreateBinary(constant->shape(), HloOpcode::kSubtract,
|
|
constant, entry->parameter_instruction(0)));
|
|
entry->set_root_instruction(new_root);
|
|
|
|
// DCE should remove everything but the parameters and the newly added code.
|
|
HloDCE dce;
|
|
TF_ASSERT_OK(dce.Run(module.get()).status());
|
|
|
|
EXPECT_EQ(schedule.sequence(entry).size(), 6);
|
|
|
|
ASSERT_IS_NOT_OK(schedule.Verify());
|
|
TF_ASSERT_OK(schedule.Update());
|
|
TF_ASSERT_OK(schedule.Verify());
|
|
|
|
EXPECT_EQ(schedule.sequence(entry).size(), 4);
|
|
}
|
|
|
|
TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) {
|
|
// Completely replace a module with an entirely new set of instructions and
|
|
// verify that the schedule can be updated successfully.
|
|
const string module_str = R"(
|
|
HloModule UpdateScheduleWithCompletelyReplacedModule
|
|
|
|
ENTRY main {
|
|
a = f32[] constant(42.0)
|
|
b = f32[] constant(123.0)
|
|
ROOT sum = f32[] add(a, b)
|
|
}
|
|
)";
|
|
|
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
ParseAndReturnVerifiedModule(module_str));
|
|
TF_ASSERT_OK_AND_ASSIGN(
|
|
HloSchedule schedule,
|
|
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
|
return ShapeUtil::ByteSizeOf(buffer.shape());
|
|
}));
|
|
|
|
// Replace the entry computation with the negation of a constant.
|
|
HloComputation* entry = module->entry_computation();
|
|
HloInstruction* constant = entry->AddInstruction(
|
|
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
|
HloInstruction* new_root = entry->AddInstruction(HloInstruction::CreateUnary(
|
|
constant->shape(), HloOpcode::kNegate, constant));
|
|
entry->set_root_instruction(new_root);
|
|
|
|
// DCE the old instructions.
|
|
HloDCE dce;
|
|
TF_ASSERT_OK(dce.Run(module.get()).status());
|
|
|
|
EXPECT_EQ(schedule.sequence(entry).size(), 3);
|
|
|
|
ASSERT_IS_NOT_OK(schedule.Verify());
|
|
TF_ASSERT_OK(schedule.Update());
|
|
TF_ASSERT_OK(schedule.Verify());
|
|
|
|
EXPECT_EQ(schedule.sequence(entry).size(), 2);
|
|
}
|
|
|
|
TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) {
|
|
// Create changes to more than one computation in an HLO module and verify
|
|
// that the schedule can be updated.
|
|
const string module_str = R"(
|
|
HloModule UpdateScheduleWithMultipleComputations
|
|
|
|
%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
|
|
%param.1 = (s32[], token[]) parameter(0)
|
|
%get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
|
|
%constant.1 = s32[] constant(1)
|
|
%add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
|
|
%get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
|
|
%after-all = token[] after-all(token[] %get-tuple-element.2)
|
|
ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
|
|
}
|
|
|
|
%Cond (param: (s32[], token[])) -> pred[] {
|
|
%param = (s32[], token[]) parameter(0)
|
|
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
|
%constant = s32[] constant(42)
|
|
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
|
|
}
|
|
|
|
ENTRY %WhileLoop () -> s32[] {
|
|
%zero = s32[] constant(0)
|
|
%init_token = token[] after-all()
|
|
%init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
|
|
%while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
|
|
ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
|
|
}
|
|
)";
|
|
|
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
ParseAndReturnVerifiedModule(module_str));
|
|
TF_ASSERT_OK_AND_ASSIGN(
|
|
HloSchedule schedule,
|
|
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
|
return ShapeUtil::ByteSizeOf(buffer.shape(),
|
|
/*pointer_size=*/sizeof(void*));
|
|
}));
|
|
|
|
const HloInstruction* xla_while =
|
|
module->entry_computation()->root_instruction()->operand(0);
|
|
HloComputation* body = xla_while->while_body();
|
|
HloComputation* cond = xla_while->while_condition();
|
|
|
|
// Negate the root of the cond.
|
|
cond->set_root_instruction(cond->AddInstruction(
|
|
HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}),
|
|
HloOpcode::kNot, cond->root_instruction())));
|
|
|
|
// Replace the body with a computation which just passes through its
|
|
// parameter.
|
|
body->set_root_instruction(body->parameter_instruction(0));
|
|
|
|
// DCE the dead code in the body.
|
|
HloDCE dce;
|
|
TF_ASSERT_OK(dce.Run(module.get()).status());
|
|
|
|
EXPECT_EQ(schedule.sequence(body).size(), 7);
|
|
EXPECT_EQ(schedule.sequence(cond).size(), 4);
|
|
|
|
ASSERT_IS_NOT_OK(schedule.Verify());
|
|
TF_ASSERT_OK(schedule.Update());
|
|
TF_ASSERT_OK(schedule.Verify());
|
|
|
|
EXPECT_EQ(schedule.sequence(body).size(), 1);
|
|
EXPECT_EQ(schedule.sequence(cond).size(), 5);
|
|
}
|
|
|
|
TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) {
|
|
// Remove computations from a module and verify the schedule can be updated.
|
|
const string module_str = R"(
|
|
HloModule UpdateScheduleWithMultipleComputations
|
|
|
|
%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
|
|
%param.1 = (s32[], token[]) parameter(0)
|
|
%get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
|
|
%constant.1 = s32[] constant(1)
|
|
%add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
|
|
%get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
|
|
%after-all = token[] after-all(token[] %get-tuple-element.2)
|
|
ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
|
|
}
|
|
|
|
%Cond (param: (s32[], token[])) -> pred[] {
|
|
%param = (s32[], token[]) parameter(0)
|
|
%get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
|
|
%constant = s32[] constant(42)
|
|
ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
|
|
}
|
|
|
|
ENTRY %WhileLoop () -> s32[] {
|
|
%zero = s32[] constant(0)
|
|
%init_token = token[] after-all()
|
|
%init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
|
|
%while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
|
|
ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
|
|
}
|
|
)";
|
|
|
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
ParseAndReturnVerifiedModule(module_str));
|
|
TF_ASSERT_OK_AND_ASSIGN(
|
|
HloSchedule schedule,
|
|
ScheduleModule(module.get(), [](const BufferValue& buffer) {
|
|
return ShapeUtil::ByteSizeOf(buffer.shape(),
|
|
/*pointer_size=*/sizeof(void*));
|
|
}));
|
|
|
|
HloInstruction* xla_while =
|
|
module->entry_computation()->root_instruction()->mutable_operand(0);
|
|
HloInstruction* init = xla_while->mutable_operand(0);
|
|
|
|
// Replace the while with its init value. The conditional and body
|
|
// computations should then be dead.
|
|
TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init));
|
|
|
|
// DCE the dead code in the body.
|
|
HloDCE dce;
|
|
ASSERT_EQ(module->computation_count(), 3);
|
|
TF_ASSERT_OK(dce.Run(module.get()).status());
|
|
ASSERT_EQ(module->computation_count(), 1);
|
|
|
|
ASSERT_IS_NOT_OK(schedule.Verify());
|
|
TF_ASSERT_OK(schedule.Update());
|
|
TF_ASSERT_OK(schedule.Verify());
|
|
}
|
|
|
|
} // namespace
|
|
} // namespace xla
|