[Resubmit][XLA] Introduce input/output alias config.
- This CL introduces an input/output alias config in HloModule that any HLO pass can configure. Once the alias config is set, each backend must follow the contract at execution time to ensure the input and output buffers are indeed aliased.
- Copy insertion, buffer assignment, and alias analysis have been updated to correctly honor the config and avoid any possible liveness interference.

PiperOrigin-RevId: 216737975
commit 028410c7f4
parent c304bd9bc9
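To illustrate the contract, here is a minimal hypothetical sketch (not part of this CL) of how a pass could configure an alias and how a backend could enumerate the configured pairs at execution time. It assumes entry parameter 0 and the module output are both tuples with a compatible element at index {0}; the helper name is invented for illustration.

#include "tensorflow/compiler/xla/service/hlo_module.h"

namespace xla {

// Hypothetical helper, for illustration only: marks element {0} of entry
// parameter 0 as aliased with element {0} of the module output, then walks
// the configured pairs the way a backend would at execution time.
Status ConfigureAndInspectAliasing(HloModule* module) {
  // SetUpAlias fails if output index {0} is already aliased.
  TF_RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));

  module->input_output_alias_config().ForEachAlias(
      [](const ShapeIndex& output_index, int64 param_number,
         const ShapeIndex& param_index) {
        // A real backend would arrange here for the output buffer at
        // output_index to reuse the parameter buffer.
        LOG(INFO) << "output " << output_index.ToString() << " aliases param "
                  << param_number << " at " << param_index.ToString();
      });
  return Status::OK();
}

}  // namespace xla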
@@ -294,6 +294,7 @@ cc_library(
     srcs = [
         "dfs_hlo_visitor.cc",
         "hlo_computation.cc",
+        "hlo_input_output_alias_config.cc",
         "hlo_instruction.cc",
         "hlo_instructions.cc",
         "hlo_module.cc",
@@ -308,6 +309,7 @@ cc_library(
         "hlo_clone_context.h",
         "hlo_computation.h",
         "hlo_domain_metadata.h",
+        "hlo_input_output_alias_config.h",
         "hlo_instruction.h",
         "hlo_instructions.h",
         "hlo_module.h",
@@ -1268,6 +1270,25 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "hlo_input_output_alias_config_test",
+    srcs = ["hlo_input_output_alias_config_test.cc"],
+    deps = [
+        ":hlo",
+        ":hlo_dce",
+        ":hlo_memory_scheduler",
+        ":hlo_ordering",
+        ":hlo_parser",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+        "@com_google_absl//absl/algorithm:container",
+    ],
+)
+
 cc_library(
     name = "hlo_memory_scheduler",
     srcs = ["hlo_memory_scheduler.cc"],
@@ -239,7 +239,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice(
 
 void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset,
                                      int64 size) {
-  VLOG(4) << "Trying to add " << buffer << " to " << this;
+  VLOG(4) << "Trying to add " << buffer << " to allocation #" << index();
   CHECK(assigned_buffers_.count(&buffer) == 0)
       << "LogicalBuffer " << buffer << " already assigned to allocation "
       << index_;
@@ -784,21 +784,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
     }
   }
 
-  if (allow_input_output_aliasing_ && allocation->maybe_live_out()) {
-    const HloComputation* entry_computation =
-        assignment->module_->entry_computation();
-    for (auto param : entry_computation->parameter_instructions()) {
-      for (auto& param_buffer :
-           assignment->points_to_analysis().GetBuffersDefinedByInstruction(
-               param)) {
-        if (assignment->liveness().MayInterfere(*param_buffer, buffer)) {
-          VLOG(4) << "Can't assign: Parameter interference with result";
-          return false;
-        }
-      }
-    }
-  }
-
   // If the buffer is live out of the computation then it should only be
   // assigned a buffer which exactly fits the result to avoid wasting memory
   // (result buffers can have arbitrary lifetimes).
@@ -1434,13 +1419,28 @@ BufferAssigner::MergeColocatedBufferSets(
 
 // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
 // in the same allocation (currently just supports kWhile, kCall, and
-// kConditional).
+// kConditional and input output aliasing).
 void BufferAssigner::BuildColocatedBufferSets(
     const HloModule* module, const BufferLiveness& buffer_liveness,
    const LogicalBuffer::SizeFunction& buffer_size,
    std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
   const TuplePointsToAnalysis& points_to_analysis =
       buffer_liveness.points_to_analysis();
 
+  // Set up colocated buffer set for input and output.
+  module->input_output_alias_config().ForEachAlias(
+      [&](const ShapeIndex& output_index, int64 param_number,
+          const ShapeIndex& param_index) {
+        std::vector<const LogicalBuffer*> colocated_set;
+        AddBufferToColocatedSet(module->entry_computation()->root_instruction(),
+                                output_index, points_to_analysis,
+                                &colocated_set);
+        AddBufferToColocatedSet(
+            module->entry_computation()->parameter_instruction(param_number),
+            param_index, points_to_analysis, &colocated_set);
+        AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
+      });
+
   for (const HloComputation* computation : module->MakeComputationPostOrder()) {
     if (computation->IsFusionComputation()) {
       continue;
@@ -141,6 +141,9 @@ class BufferValue {
   // operator< is required for std::set.
   bool operator<(const BufferValue& other) const { return id_ < other.id_; }
 
+  bool operator==(const BufferValue& other) const { return id_ == other.id_; }
+  bool operator!=(const BufferValue& other) const { return id_ != other.id_; }
+
   virtual string ToString() const = 0;
 
   // TODO(lauj) rename LogicalBufferProto to BufferValueProto.
@@ -40,10 +40,12 @@ namespace {
 
 using absl::StrAppend;
 
-bool IsEntryParameterValue(const HloValue& value) {
+bool IsReadonlyEntryParameterValue(const HloValue& value) {
   const HloComputation* computation = value.defining_instruction()->parent();
   return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
-         computation == computation->parent()->entry_computation();
+         computation == computation->parent()->entry_computation() &&
+         !computation->parent()->input_output_alias_config().ParameterHasAlias(
+             value.defining_instruction()->parameter_number(), value.index());
 }
 
 bool IsConstantValue(const HloValue& value) {
@@ -51,7 +53,7 @@ bool IsConstantValue(const HloValue& value) {
 }
 
 bool ValueIsReadOnly(const HloValue& value) {
-  return IsConstantValue(value) || IsEntryParameterValue(value);
+  return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
 }
 
 // Data structure describing the action which should be taken on parts of a
@@ -79,8 +81,7 @@ SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
 bool ShouldCopyRootValue(const HloValue& value,
                          const SpecialCaseCopyPolicy& policy) {
   if (policy.copy_parameters_and_constants) {
-    return IsConstantValue(value) ||
-           value.defining_instruction()->opcode() == HloOpcode::kParameter;
+    return ValueIsReadOnly(value);
   }
   return false;
 }
@@ -332,6 +333,81 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
   return Status::OK();
 }
 
+// Conservatively adds copies before the root instruction of the entry
+// computation and each aliased parameter to resolve interference between
+// aliased input and output buffers. We later rely on the CopyRemover to drop
+// the unnecessary ones.
+Status AddCopiesForAliasedInputOutputs(HloModule* module) {
+  HloComputation* entry = module->entry_computation();
+  HloInstruction* root = entry->root_instruction();
+
+  ShapeTree<bool> output_indices_to_copy(root->shape());
+  std::vector<ShapeTree<HloInstruction*>> copied_parameters;
+  bool has_alias = false;
+  for (auto* param : entry->parameter_instructions()) {
+    bool param_has_alias = false;
+    ShapeTree<bool> param_indices_to_copy(param->shape());
+
+    module->input_output_alias_config().ForEachAlias(
+        [&](const ShapeIndex& output_index, int64 param_number,
+            const ShapeIndex& param_index) {
+          if (param_number == param->parameter_number()) {
+            param_has_alias = true;
+            *(param_indices_to_copy.mutable_element(param_index)) = true;
+            *(output_indices_to_copy.mutable_element(output_index)) = true;
+          }
+        });
+
+    if (!param_has_alias) {
+      continue;
+    }
+
+    has_alias = true;
+    // Store a snapshot of users before DeepCopyInstruction, as
+    // DeepCopyInstruction introduces new users of the instruction.
+    std::vector<HloInstruction*> users = param->users();
+    ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
+                                               /*init_value=*/nullptr);
+    TF_ASSIGN_OR_RETURN(HloInstruction * copied,
+                        entry->DeepCopyInstruction(
+                            param, &param_indices_to_copy, &param_copy_tree));
+    for (HloInstruction* user : users) {
+      TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
+    }
+
+    copied_parameters.push_back(param_copy_tree);
+  }
+
+  if (!has_alias) {
+    return Status::OK();
+  }
+
+  // Add copies before root instruction.
+  ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
+                                              /*init_value=*/nullptr);
+
+  TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
+                      root->parent()->DeepCopyInstruction(
+                          root, &output_indices_to_copy, &output_copy_tree));
+
+  // Add control dependencies between the input/output copies.
+  TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
+      [&](const ShapeIndex& output_index, int64 param_number,
+          const ShapeIndex& input_index) -> Status {
+        HloInstruction* from =
+            copied_parameters[param_number].element(input_index);
+        HloInstruction* to = output_copy_tree.element(output_index);
+
+        TF_RET_CHECK(from != nullptr);
+        TF_RET_CHECK(to != nullptr);
+        TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
+        return Status::OK();
+      }));
+
+  entry->set_root_instruction(root_copied);
+
+  return Status::OK();
+}
+
 // Removes any control dependencies to or from the given instruction.
 Status StripControlDependenciesFrom(HloInstruction* instruction) {
   while (!instruction->control_successors().empty()) {
@@ -953,6 +1029,8 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
       }
     }
   }
 
+  TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
+
   return Status::OK();
 }
@@ -1351,6 +1351,218 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
   EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
 }
 
+TEST_F(CopyInsertionTest, CrossingParameters) {
+  // Test a case where two parameters' dataflow cross with each other while
+  // input and output are aliased with same index:
+  //
+  //  (p0  ,  p1)
+  //    | \   / |
+  //    |  \ /  |
+  // alias  X  alias
+  //    |  / \  |
+  //    | /   \ |
+  //  (p1  ,  p0)
+  auto module = CreateNewModule();
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "0"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+  builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0}));
+  module->AddEntryComputation(builder.Build());
+  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+  InsertCopies(module.get());
+
+  EXPECT_EQ(CountCopies(*module), 4);
+}
+
+TEST_F(CopyInsertionTest, ParametersAliasing) {
+  // Test a case where two parameters' dataflow don't interfere with each other
+  // while aliased.
+  //
+  //  (p0  ,  p1)
+  //    |      |
+  //    |      |
+  // alias   alias
+  //    |      |
+  //    |      |
+  //  (p0  ,  p1)
+  auto module = CreateNewModule();
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+  builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+  module->AddEntryComputation(builder.Build());
+  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+  InsertCopies(module.get());
+
+  EXPECT_EQ(CountCopies(*module), 0);
+}
+
+TEST_F(CopyInsertionTest, ParameterWithNoAliasing) {
+  // Test a case where no parameter is aliased with the result. In this case,
+  // copies should be added.
+  //
+  //  (p0  ,  p1)
+  //    |      |
+  //    |      |
+  //    |      |
+  //    |      |
+  //    |      |
+  //  (p0  ,  p1)
+  auto module = CreateNewModule();
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+  builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+  module->AddEntryComputation(builder.Build());
+  InsertCopies(module.get());
+
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::Copy(op::GetTupleElement(param, 0)),
+                        op::Copy(op::GetTupleElement(param, 1))));
+
+  EXPECT_EQ(CountCopies(*module), 2);
+}
+
+TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
+  // Test a case where one parameter is aliased with the result while another
+  // one isn't.
+  //
+  //  (p0  ,  p1)
+  //    |      |
+  //    |      |
+  // alias     |
+  //    |      |
+  //    |      |
+  //  (p0  ,  p1)
+  auto module = CreateNewModule();
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+  builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+  module->AddEntryComputation(builder.Build());
+  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+  InsertCopies(module.get());
+
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::GetTupleElement(param, 0),
+                        op::Copy(op::GetTupleElement(param, 1))));
+
+  EXPECT_EQ(CountCopies(*module), 1);
+}
+
+TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
+  // Test a case where one parameter is aliased with the result while another
+  // one isn't.
+  //
+  //   +-- (p0  ,  p1)
+  //   |     |      |
+  //   |     |      |
+  // alias Negate Negate
+  //   |     |      |
+  //   |     |      |
+  //   +-- (p0  ,  p1)
+  auto module = CreateNewModule();
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+
+  auto negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
+
+  auto negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
+  builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
+  module->AddEntryComputation(builder.Build());
+  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+  InsertCopies(module.get());
+
+  EXPECT_EQ(CountCopies(*module), 0);
+}
+
+TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
+  // Test a case where one parameter is aliased with the result while another
+  // one isn't.
+  //
+  //   +-- (p0  ,  p1)
+  //   |     |      |
+  //   |     |      |
+  // alias Negate Negate
+  //   |     |      |
+  //   |     Add----+
+  //   |     |      |
+  //   +-- (p0  ,  p1)
+  auto module = CreateNewModule();
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+
+  auto negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
+
+  auto negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
+
+  auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape_, HloOpcode::kAdd, negate0, negate1));
+  builder.AddInstruction(HloInstruction::CreateTuple({add, negate1}));
+  module->AddEntryComputation(builder.Build());
+  ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+  InsertCopies(module.get());
+
+  EXPECT_EQ(CountCopies(*module), 0);
+}
+
 TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
   // Test a while instruction with a body which permutes its tuple parameter
   // elements and applies one operation to one of the elements. The addition of
@@ -225,6 +225,32 @@ message HloScheduleProto {
   map<int64, InstructionSequence> sequences = 1;
 }
 
+message HloInputOutputAliasProto {
+  // The following proto describes a pair of an aliased input (described by
+  // parameter number and a ShapeIndex of the parameter) and an output
+  // (described by a ShapeIndex of the root instruction). For example:
+  //
+  // entry = {
+  //   output_shape_index={1},
+  //   parameter_number=0,
+  //   parameter_shape_index={1, 2},
+  // }
+  //
+  // This entry indicates that the first parameter's {1, 2} element is
+  // aliased with the {1} element of the root instruction.
+  message AliasEntryProto {
+    // ShapeIndex of the root hlo.
+    repeated int64 output_shape_index = 1;
+    // Number of the parameter in entry computation.
+    int64 parameter_number = 2;
+    // ShapeIndex of the parameter instruction.
+    repeated int64 parameter_shape_index = 3;
+  }
+
+  repeated AliasEntryProto entries = 1;
+}
+
 // Serialization of HloModule.
 message HloModuleProto {
   string name = 1;
@@ -243,6 +269,9 @@ message HloModuleProto {
 
   // The schedule for this module.
   HloScheduleProto schedule = 7;
 
+  // Describes alias information between inputs and outputs.
+  HloInputOutputAliasProto input_output_alias = 8;
 }
 
 // Serialization of LogicalBuffer.
@@ -59,8 +59,9 @@ class BufferValueMap {
   // construction process.
   using BufferNumber = int64;
 
-  explicit BufferValueMap(const HloDataflowAnalysis& dataflow)
-      : dataflow_(dataflow) {
+  explicit BufferValueMap(HloModule* module,
+                          const HloDataflowAnalysis& dataflow)
+      : module_(module), dataflow_(dataflow) {
     buffers_.reserve(dataflow_.values().size());
     value_to_buffer_number_.reserve(dataflow_.values().size());
     for (const HloValue* value : dataflow_.values()) {
@@ -171,6 +172,42 @@ class BufferValueMap {
     return value_to_buffer_number_.at(&value);
   }
 
+  void ComputeInputOutputAliasedBuffers(
+      const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
+    // Get parameter value from an aliased_input object.
+    const auto get_parameter_value =
+        [this](const std::pair<int64, ShapeIndex>& aliased_input)
+        -> const HloValue& {
+      int64 param_number = aliased_input.first;
+      const ShapeIndex& param_index = aliased_input.second;
+      return dataflow_.GetUniqueValueAt(
+          module_->entry_computation()->parameter_instruction(param_number),
+          param_index);
+    };
+
+    // If the value shows up in a root instruction, alias it with the parameter
+    // instruction.
+    for (const HloPosition& pos : value.positions()) {
+      if (pos.instruction == module_->entry_computation()->root_instruction()) {
+        ShapeIndex output_index = pos.index;
+
+        auto aliased_input =
+            module_->input_output_alias_config().GetAliasedParameter(
+                output_index);
+        if (aliased_input) {
+          aliased_buffers->push_back(
+              GetBufferForValue(get_parameter_value(*aliased_input)));
+        }
+      }
+    }
+
+    // If the value is the parameter instruction itself, alias it with itself.
+    if (value.instruction()->opcode() == HloOpcode::kParameter &&
+        value.instruction()->parent() == module_->entry_computation()) {
+      aliased_buffers->push_back(GetBufferForValue(value));
+    }
+  }
+
   void ComputeWhileAliasedBuffers(const HloValue& value,
                                   std::vector<BufferNumber>* aliased_buffers) {
     VLOG(3) << "Compute kWhile aliases";
@@ -278,6 +315,7 @@ class BufferValueMap {
       VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
     }
     std::vector<BufferNumber> aliased_buffers;
+    ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
     ComputeWhileAliasedBuffers(value, &aliased_buffers);
     ComputeConditionalAliasedBuffers(value, &aliased_buffers);
     // Uniquify aliased buffers.
@@ -288,6 +326,8 @@ class BufferValueMap {
     return aliased_buffers;
  }
 
+  HloModule* module_;
+
   // Dataflow analysis used to construct the buffer map.
   const HloDataflowAnalysis& dataflow_;
 
@@ -461,7 +501,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
       /*bitcast_defines_value=*/false,
       fusion_can_share_buffer));
 
-  BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
+  BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
   buffer_map.MergeAliasedBuffers();
 
   // Create a vector of HloBuffers, one for each set of values in the
@@ -217,6 +217,181 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
   EXPECT_FALSE(AnyValuesInSameBufferInterfere());
 }
 
+TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) {
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+
+  auto negate0 = builder.AddInstruction(
+      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
+  auto negate1 = builder.AddInstruction(
+      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
+
+  auto tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
+  module_->AddEntryComputation(builder.Build());
+  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+
+  // Cannot alias an output twice.
+  ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
+
+  const HloAliasAnalysis& analysis = RunAnalysis();
+
+  EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
+            analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
+
+  EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
+            analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
+}
+
+TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) {
+  // Parameter 0 is aliased with output 1 and parameter 1 is aliased with
+  // output 0:
+  //
+  //  (p0  ,  p1)
+  //     \   /
+  //      \ /
+  // alias X
+  //      / \
+  //     /   \
+  //  (p0  ,  p1)
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+  auto gte0 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
+  auto gte1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
+  auto tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
+  module_->AddEntryComputation(builder.Build());
+  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}));
+  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
+
+  // Cannot alias an output twice.
+  ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+
+  const HloAliasAnalysis& analysis = RunAnalysis();
+
+  // All ops in this graph are aliased with each other.
+  EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
+            analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
+  EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
+            analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
+
+  EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
+            analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
+  EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
+            analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
+}
+
+TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) {
+  // Test that a simple while instruction can be aliased with the input and
+  // output of the computation.
+  //
+  // body((F32[], F32[]) %tuple_param):
+  //   %add = Add(%tuple_param{0}, %tuple_param{1})
+  //   return Tuple(%tuple_param{0}, %add)
+  //
+  // condition((F32[], F32[]) %tuple_param):
+  //   return Constant(false)
+  //
+  // entry:
+  //   %param1 = param1
+  //   %while = While(%param1, body, condition)
+  //   %while_1 = GTE(%while, 0)
+  //   %while_2 = GTE(%while, 1)
+  //   %negate_1 = Negate(%while_1)
+  //   %negate_2 = Negate(%while_2)
+  //   return Tuple(negate_1, negate_2)
+  //
+  const Shape tuple_shape =
+      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
+
+  // Element 0 passes transparently through the body.
+  auto body_builder = HloComputation::Builder("body");
+  auto body_param = body_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  auto body_element_0 = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
+  auto body_element_1 = body_builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
+  auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
+      scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
+  auto body_tuple = body_builder.AddInstruction(
+      HloInstruction::CreateTuple({body_element_0, add}));
+  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
+
+  // Condition computation trivially returns a constant "false".
+  auto cond_builder = HloComputation::Builder("condition");
+  auto cond_param = cond_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "param"));
+  cond_builder.AddInstruction(
+      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
+  HloComputation* condition =
+      module_->AddEmbeddedComputation(cond_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
+
+  auto xla_while = builder.AddInstruction(
+      HloInstruction::CreateWhile(tuple_shape, condition, body, param));
+  auto while_element_1 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0));
+  auto while_element_2 = builder.AddInstruction(
+      HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1));
+  auto negate_1 = builder.AddInstruction(HloInstruction::CreateUnary(
+      scalar_shape_, HloOpcode::kNegate, while_element_1));
+  auto negate_2 = builder.AddInstruction(HloInstruction::CreateUnary(
+      scalar_shape_, HloOpcode::kNegate, while_element_2));
+  auto tuple =
+      builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2}));
+  module_->AddEntryComputation(builder.Build());
+  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
+  TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
+      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
+
+  const HloAliasAnalysis& analysis = RunAnalysis();
+
+  EXPECT_THAT(
+      GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
+      UnorderedElementsAre(GetValueDefinedAt(param, {1}),
+                           GetValueDefinedAt(xla_while, /*index=*/{1}),
+                           GetValueDefinedAt(body_param, {1}),
+                           GetValueDefinedAt(cond_param, {1}),
+                           GetValueDefinedAt(add),
+                           GetValueDefinedAt(negate_2)));
+
+  EXPECT_THAT(
+      analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(),
+      UnorderedElementsAre(
+          HloPosition{param, {1}}, HloPosition{xla_while, {1}},
+          HloPosition{while_element_2, {}}, HloPosition{body_param, {1}},
+          HloPosition{body_element_1, {}}, HloPosition{add, {}},
+          HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}},
+          HloPosition{cond_param, {1}}, HloPosition{negate_2, {}}));
+
+  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
+}
+
 TEST_F(HloAliasAnalysisTest, SingleCall) {
   // Test a single call of a subcomputation. The subcomputation adds its two
   // array-shaped parameters.
@@ -126,7 +126,7 @@ bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
 
 const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
     const HloInstruction* instruction, const ShapeIndex& index) const {
-  CHECK(ValueIsDefinedAt(instruction, index));
+  CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
   return GetUniqueValueAt(instruction, index);
 }
new file: tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc (182 lines)
@@ -0,0 +1,182 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+namespace xla {
+Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
+                                             int64 param_number,
+                                             const ShapeIndex& param_index) {
+  TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
+      << absl::StrCat("Trying to set up alias at ", output_index.ToString(),
+                      " which is an invalid index for shape ",
+                      ShapeUtil::HumanString(alias_.shape()));
+  // Output can't be aliased with multiple parameters.
+  TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat(
+      "Trying to set up output alias for param %lld at %s but failed: output "
+      "index %s is already aliased with param %lld at %s",
+      param_number, param_index.ToString(), output_index.ToString(),
+      alias_.element(output_index)->first,
+      alias_.element(output_index)->second.ToString());
+  (*alias_.mutable_element(output_index)) =
+      std::make_pair(param_number, param_index);
+  return Status::OK();
+}
+
+HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
+  HloInputOutputAliasProto result;
+  alias_.ForEachElement(
+      [&](const ShapeIndex& index,
+          const absl::optional<std::pair<int64, ShapeIndex>>& data) {
+        if (data) {
+          HloInputOutputAliasProto::AliasEntryProto entry;
+          for (int64 i : index) {
+            entry.add_output_shape_index(i);
+          }
+          entry.set_parameter_number(data->first);
+          for (int64 i : data->second) {
+            entry.add_parameter_shape_index(i);
+          }
+          result.add_entries()->Swap(&entry);
+        }
+      });
+  return result;
+}
+
+StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
+    const Shape& output_shape, const HloInputOutputAliasProto& proto) {
+  HloInputOutputAliasConfig result(output_shape);
+  for (const HloInputOutputAliasProto::AliasEntryProto& entry :
+       proto.entries()) {
+    ShapeIndex output_index(entry.output_shape_index().begin(),
+                            entry.output_shape_index().end());
+
+    int64 param_number = entry.parameter_number();
+    ShapeIndex param_index(entry.parameter_shape_index().begin(),
+                           entry.parameter_shape_index().end());
+    TF_RETURN_IF_ERROR(
+        result.SetUpAlias(output_index, param_number, param_index));
+  }
+
+  return result;
+}
+
+string HloInputOutputAliasConfig::ToString() const {
+  std::vector<string> pieces;
+  pieces.push_back("HloInputOutputAliasConfig");
+
+  ForEachAlias([&](const ShapeIndex& output_index, int64 param_number,
+                   const ShapeIndex& param_index) {
+    pieces.push_back(absl::StrFormat(
+        "  OutputIndex %s is aliased with parameter %lld at %s:",
+        output_index.ToString(), param_number, param_index.ToString()));
+  });
+
+  return absl::StrJoin(pieces, "\n");
+}
+
+bool HloInputOutputAliasConfig::ParameterHasAlias(
+    int64 param_number, const ShapeIndex& param_index) const {
+  bool output = false;
+  alias_.ForEachElement(
+      [&](const xla::ShapeIndex&,
+          absl::optional<std::pair<int64, ShapeIndex>> alias) {
+        if (alias && alias->first == param_number &&
+            alias->second == param_index) {
+          output = true;
+        }
+      });
+  return output;
+}
+
+absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
+    int64 param_number, const ShapeIndex& param_index) const {
+  absl::optional<ShapeIndex> output;
+  alias_.ForEachElement(
+      [&](const xla::ShapeIndex& output_index,
+          absl::optional<std::pair<int64, ShapeIndex>> alias) {
+        if (alias && alias->first == param_number &&
+            alias->second == param_index) {
+          output = output_index;
+        }
+      });
+  return output;
+}
+
+absl::optional<std::pair<int64, ShapeIndex>>
+HloInputOutputAliasConfig::GetAliasedParameter(
+    const ShapeIndex& output_index) const {
+  CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
+  return alias_.element(output_index);
+}
+
+void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const {
+  alias_.ForEachElement(
+      [&](const ShapeIndex& output_index,
+          absl::optional<std::pair<int64, ShapeIndex>> aliased) {
+        if (aliased) {
+          fn(output_index, aliased->first, aliased->second);
+        }
+      });
+}
+
+Status HloInputOutputAliasConfig::ForEachAliasWithStatus(
+    AliasFnWithStatus fn) const {
+  return alias_.ForEachElementWithStatus(
+      [&](const ShapeIndex& output_index,
+          absl::optional<std::pair<int64, ShapeIndex>> aliased) {
+        if (aliased) {
+          TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second));
+        }
+        return Status::OK();
+      });
+}
+
+Status HloInputOutputAliasConfig::Verify(const HloModule& module) const {
+  std::vector<ShapeTree<bool>> param_has_seen;
+  const HloComputation* entry = module.entry_computation();
+  for (int64 i = 0; i < entry->num_parameters(); ++i) {
+    HloInstruction* param = entry->parameter_instruction(i);
+    param_has_seen.emplace_back(param->shape());
+  }
+  return ForEachAliasWithStatus([&](const ShapeIndex& output_index,
+                                    int64 param_number,
+                                    const ShapeIndex& param_index) -> Status {
+    const HloInstruction* root = entry->root_instruction();
+
+    const Shape& param_shape =
+        entry->parameter_instruction(param_number)->shape();
+    const Shape& output_shape = root->shape();
+    TF_RET_CHECK(entry->num_parameters() > param_number);
+    TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index));
+    TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index));
+
+    // Check that each (param_number, param_index) pair shows up only once: an
+    // input buffer must not be aliased with more than one output buffer.
+    TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false);
+
+    *(param_has_seen[param_number].mutable_element(param_index)) = true;
+
+    return Status::OK();
+  });
+}
+
+std::ostream& operator<<(std::ostream& out,
+                         const HloInputOutputAliasConfig& config) {
+  out << config.ToString();
+  return out;
+}
+}  // namespace xla
new file: tensorflow/compiler/xla/service/hlo_input_output_alias_config.h (102 lines)
@@ -0,0 +1,102 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
+
+#include <utility>
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+
+namespace xla {
+
+class HloModule;
+
+// This class specifies the alias map from output index to parameter number and
+// parameter index in the entry computation.
+class HloInputOutputAliasConfig {
+ public:
+  HloInputOutputAliasConfig() = default;
+
+  explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {}
+
+  virtual ~HloInputOutputAliasConfig() = default;
+
+  // Sets up alias config from `output_index` to `param_index` at
+  // `param_number`.
+  Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
+                    const ShapeIndex& param_index);
+
+  // Returns true if the given parameter is aliased with one of the output
+  // buffers.
+  bool ParameterHasAlias(int64 param_number,
+                         const ShapeIndex& param_index) const;
+
+  // (De)Serializes an HloInputOutputAliasConfig to/from an
+  // HloInputOutputAliasProto.
+  HloInputOutputAliasProto ToProto() const;
+
+  static StatusOr<HloInputOutputAliasConfig> CreateFromProto(
+      const Shape& output_shape, const HloInputOutputAliasProto& proto);
+
+  // Returns the output index that the given parameter and parameter index is
+  // aliased with. A nullopt is returned if there is no output that is aliased
+  // with the parameter number and index.
+  absl::optional<ShapeIndex> GetAliasedOutput(
+      int64 param_number, const ShapeIndex& param_index) const;
+
+  // Returns the parameter number and index of the parameter buffer that the
+  // given output buffer index is aliased with. A nullopt is returned if no
+  // parameter is aliased with the specific output.
+  absl::optional<std::pair<int64, ShapeIndex>> GetAliasedParameter(
+      const ShapeIndex& output_index) const;
+
+  using AliasFn =
+      std::function<void(const ShapeIndex& output_index, int64 param_number,
+                         const ShapeIndex& param_index)>;
+
+  // Iterates through each aliased output and input.
+  void ForEachAlias(AliasFn fn) const;
+
+  using AliasFnWithStatus =
+      std::function<Status(const ShapeIndex& output_index, int64 param_number,
+                           const ShapeIndex& param_index)>;
+
+  // Verifies that the given config is valid for the given module.
+  // Specifically, the config's input and output should be in-bounds, and the
+  // sizes of the aliased buffers should match.
+  Status Verify(const HloModule& module) const;
+
+  Status ForEachAliasWithStatus(AliasFnWithStatus fn) const;
+
+  string ToString() const;
+
+ private:
+  // A ShapeTree which indicates the list of buffers that are expected to be
+  // aliased. The key on this shape tree represents the output index. The
+  // value is a pair of parameter number and index into the buffer. If the
+  // value is nullopt, it means there is no parameter aliasing for this output.
+  ShapeTree<absl::optional<std::pair<int64, ShapeIndex>>> alias_;
+};
+
+std::ostream& operator<<(std::ostream& out,
+                         const HloInputOutputAliasConfig& config);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
new file: hlo_input_output_alias_config_test.cc (184 lines)
@@ -0,0 +1,184 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
+
+#include <memory>
+#include <string>
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dce.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+class HloInputOutputAliasConfigTest : public HloTestBase {
+ protected:
+  void expect_aliased(const ShapeIndex& output_index, int64 param_number,
+                      const ShapeIndex& param_index,
+                      const HloInputOutputAliasConfig& config) {
+    absl::optional<ShapeIndex> aliased_output =
+        config.GetAliasedOutput(param_number, param_index);
+
+    EXPECT_TRUE(aliased_output);
+    EXPECT_EQ(aliased_output.value(), output_index);
+
+    absl::optional<std::pair<int64, ShapeIndex>> aliased_param =
+        config.GetAliasedParameter(output_index);
+
+    EXPECT_TRUE(aliased_param);
+    EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index));
+  }
+
+  void expect_not_aliased(const ShapeIndex& output_index, int64 param_number,
+                          const ShapeIndex& param_index,
+                          const HloInputOutputAliasConfig& config) {
+    absl::optional<ShapeIndex> aliased_output =
+        config.GetAliasedOutput(param_number, param_index);
+
+    EXPECT_FALSE(aliased_output && aliased_output == output_index);
+
+    absl::optional<std::pair<int64, ShapeIndex>> aliased_param =
+        config.GetAliasedParameter(output_index);
+
+    EXPECT_FALSE(aliased_param && aliased_param->first == param_number &&
+                 aliased_param->second == param_index);
+  }
+};
+
+TEST_F(HloInputOutputAliasConfigTest, SimpleAliasing) {
+  const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT root = (f32[], f32[]) tuple(%a, %b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+
+  HloInputOutputAliasConfig config(
+      module->entry_computation()->root_instruction()->shape());
+
+  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
+                                 /*param_index=*/{}));
+
+  expect_aliased(/*output_index=*/{0}, /*param_number=*/1,
+                 /*param_index=*/{}, config);
+
+  expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
+                     /*param_index=*/{}, config);
+
+  expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
+                     /*param_index=*/{}, config);
+}
+
+TEST_F(HloInputOutputAliasConfigTest, SimpleAliasingWithTupleInput) {
+  const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+  param = (f32[], f32[]) parameter(0)
+  gte1 = f32[] get-tuple-element(%param), index=0
+  gte2 = f32[] get-tuple-element(%param), index=1
+  ROOT root = (f32[], f32[]) tuple(%gte1, %gte2)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+
+  HloInputOutputAliasConfig config(
+      module->entry_computation()->root_instruction()->shape());
+
+  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
+                                 /*param_index=*/{0}));
+
+  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
+                                 /*param_index=*/{1}));
+
+  expect_aliased(/*output_index=*/{0}, /*param_number=*/0,
+                 /*param_index=*/{0}, config);
+
+  expect_aliased(/*output_index=*/{1}, /*param_number=*/0,
+                 /*param_index=*/{1}, config);
+
+  expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
+                     /*param_index=*/{}, config);
+
+  expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
+                     /*param_index=*/{}, config);
+}
+
+TEST_F(HloInputOutputAliasConfigTest, InputDoNotAliasTwice) {
+  const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT root = (f32[], f32[]) tuple(%a, %b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+
+  HloInputOutputAliasConfig config(
+      module->entry_computation()->root_instruction()->shape());
+
+  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
+                                 /*param_index=*/{}));
+
+  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
+                                 /*param_index=*/{}));
+
+  ASSERT_IS_NOT_OK(config.Verify(*module));
+}
+
+TEST_F(HloInputOutputAliasConfigTest, OutputDoNotAliasTwice) {
+  const string module_str = R"(
+HloModule TEST
+
+ENTRY main {
+  a = f32[] parameter(0)
+  b = f32[] parameter(1)
+  ROOT root = (f32[], f32[]) tuple(%a, %b)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(module_str));
+
+  HloInputOutputAliasConfig config(
+      module->entry_computation()->root_instruction()->shape());
+
+  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
+                                 /*param_index=*/{}));
+
+  ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
+                                     /*param_index=*/{}));
+}
+}  // namespace
+}  // namespace xla
@@ -73,6 +73,8 @@ HloComputation* HloModule::AddComputationInternal(
       config_.SetDefaultComputationLayout(
           entry_computation_->ComputeProgramShape());
     }
+    input_output_alias_config_ = HloInputOutputAliasConfig(
+        entry_computation_->root_instruction()->shape());
   }
 
   if (uniquify_identifiers) {
@@ -252,6 +254,9 @@ HloModuleProto HloModule::ToProto() const {
   if (has_schedule()) {
     *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
   }
 
+  *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
+
   return proto;
 }
 
@@ -328,6 +333,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
   }
   TF_RET_CHECK(module->entry_computation_ != nullptr);
 
+  TF_ASSIGN_OR_RETURN(module->input_output_alias_config_,
+                      HloInputOutputAliasConfig::CreateFromProto(
+                          result_shape, proto.input_output_alias()));
+
   // Because we didn't uniquify the names or the ids, double-check that the
   // instruction and computation names and ids are unique from the proto.
   absl::flat_hash_set<string> computation_names;
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
@@ -222,6 +223,15 @@ class HloModule {
     return result;
   }
 
+  // input_output_alias_config indicates the list of aliased buffers that are
+  // expected from the module.
+  HloInputOutputAliasConfig& input_output_alias_config() {
+    return input_output_alias_config_;
+  }
+  const HloInputOutputAliasConfig& input_output_alias_config() const {
+    return input_output_alias_config_;
+  }
+
   // Returns an id that is unique to this module across all modules created over
   // the lifetime of this process.
   int unique_id() const { return unique_id_; }
@@ -290,6 +300,10 @@ class HloModule {
   // sequential order of instructions for each non-fusion computation in the
   // module.
   absl::optional<HloSchedule> schedule_;
+
+  // alias_config indicates the alias information of input/output buffers that
+  // are expected from the module.
+  HloInputOutputAliasConfig input_output_alias_config_;
 };
 
 }  // namespace xla
@@ -1316,6 +1316,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
     TF_RETURN_IF_ERROR(module->schedule().Verify());
   }
 
+  TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(*module));
+
   return false;
 }
@@ -72,7 +72,7 @@ class ShapeIndex {
   void push_back(int64 value) { indices_.push_back(value); }
   void pop_back() { indices_.pop_back(); }
 
-  // push_front is O(n^2), but shapes don't usually have a ton of dimensions.
+  // push_front is O(n), but shapes don't usually have a ton of dimensions.
   void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
 
   using container_type = absl::InlinedVector<int64, 2>;