[Resubmit][XLA] Introduce input/output alias config.
- This CL introduces an input/output alias config in the HLO module that any HLO pass can configure. Once the alias_config is set, each backend needs to follow the contract at execution time to make sure the input and output are indeed aliased. - Copy insertion, buffer assignment, and alias analysis have been updated to correctly honor the config and avoid any possible liveness interference. PiperOrigin-RevId: 216737975
This commit is contained in:
parent
c304bd9bc9
commit
028410c7f4
@ -294,6 +294,7 @@ cc_library(
|
|||||||
srcs = [
|
srcs = [
|
||||||
"dfs_hlo_visitor.cc",
|
"dfs_hlo_visitor.cc",
|
||||||
"hlo_computation.cc",
|
"hlo_computation.cc",
|
||||||
|
"hlo_input_output_alias_config.cc",
|
||||||
"hlo_instruction.cc",
|
"hlo_instruction.cc",
|
||||||
"hlo_instructions.cc",
|
"hlo_instructions.cc",
|
||||||
"hlo_module.cc",
|
"hlo_module.cc",
|
||||||
@ -308,6 +309,7 @@ cc_library(
|
|||||||
"hlo_clone_context.h",
|
"hlo_clone_context.h",
|
||||||
"hlo_computation.h",
|
"hlo_computation.h",
|
||||||
"hlo_domain_metadata.h",
|
"hlo_domain_metadata.h",
|
||||||
|
"hlo_input_output_alias_config.h",
|
||||||
"hlo_instruction.h",
|
"hlo_instruction.h",
|
||||||
"hlo_instructions.h",
|
"hlo_instructions.h",
|
||||||
"hlo_module.h",
|
"hlo_module.h",
|
||||||
@ -1268,6 +1270,25 @@ tf_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "hlo_input_output_alias_config_test",
|
||||||
|
srcs = ["hlo_input_output_alias_config_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":hlo",
|
||||||
|
":hlo_dce",
|
||||||
|
":hlo_memory_scheduler",
|
||||||
|
":hlo_ordering",
|
||||||
|
":hlo_parser",
|
||||||
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
|
"//tensorflow/compiler/xla:types",
|
||||||
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
|
"//tensorflow/compiler/xla/tests:hlo_test_base",
|
||||||
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"@com_google_absl//absl/algorithm:container",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "hlo_memory_scheduler",
|
name = "hlo_memory_scheduler",
|
||||||
srcs = ["hlo_memory_scheduler.cc"],
|
srcs = ["hlo_memory_scheduler.cc"],
|
||||||
|
@ -239,7 +239,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice(
|
|||||||
|
|
||||||
void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset,
|
void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset,
|
||||||
int64 size) {
|
int64 size) {
|
||||||
VLOG(4) << "Trying to add " << buffer << " to " << this;
|
VLOG(4) << "Trying to add " << buffer << " to allocation #" << index();
|
||||||
CHECK(assigned_buffers_.count(&buffer) == 0)
|
CHECK(assigned_buffers_.count(&buffer) == 0)
|
||||||
<< "LogicalBuffer " << buffer << " already assigned to allocation "
|
<< "LogicalBuffer " << buffer << " already assigned to allocation "
|
||||||
<< index_;
|
<< index_;
|
||||||
@ -784,21 +784,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (allow_input_output_aliasing_ && allocation->maybe_live_out()) {
|
|
||||||
const HloComputation* entry_computation =
|
|
||||||
assignment->module_->entry_computation();
|
|
||||||
for (auto param : entry_computation->parameter_instructions()) {
|
|
||||||
for (auto& param_buffer :
|
|
||||||
assignment->points_to_analysis().GetBuffersDefinedByInstruction(
|
|
||||||
param)) {
|
|
||||||
if (assignment->liveness().MayInterfere(*param_buffer, buffer)) {
|
|
||||||
VLOG(4) << "Can't assign: Parameter interference with result";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the buffer is live out of the computation then it should only be
|
// If the buffer is live out of the computation then it should only be
|
||||||
// assigned a buffer which exactly fits the result to avoid wasting memory
|
// assigned a buffer which exactly fits the result to avoid wasting memory
|
||||||
// (result buffers can have arbitrary lifetimes).
|
// (result buffers can have arbitrary lifetimes).
|
||||||
@ -1434,13 +1419,28 @@ BufferAssigner::MergeColocatedBufferSets(
|
|||||||
|
|
||||||
// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
|
// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
|
||||||
// in the same allocation (currently just supports kWhile, kCall, and
|
// in the same allocation (currently just supports kWhile, kCall, and
|
||||||
// kConditional).
|
// kConditional and input output aliasing).
|
||||||
void BufferAssigner::BuildColocatedBufferSets(
|
void BufferAssigner::BuildColocatedBufferSets(
|
||||||
const HloModule* module, const BufferLiveness& buffer_liveness,
|
const HloModule* module, const BufferLiveness& buffer_liveness,
|
||||||
const LogicalBuffer::SizeFunction& buffer_size,
|
const LogicalBuffer::SizeFunction& buffer_size,
|
||||||
std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
|
std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
|
||||||
const TuplePointsToAnalysis& points_to_analysis =
|
const TuplePointsToAnalysis& points_to_analysis =
|
||||||
buffer_liveness.points_to_analysis();
|
buffer_liveness.points_to_analysis();
|
||||||
|
|
||||||
|
// Set up colocated buffer set for input and output.
|
||||||
|
module->input_output_alias_config().ForEachAlias(
|
||||||
|
[&](const ShapeIndex& output_index, int64 param_number,
|
||||||
|
const ShapeIndex& param_index) {
|
||||||
|
std::vector<const LogicalBuffer*> colocated_set;
|
||||||
|
AddBufferToColocatedSet(module->entry_computation()->root_instruction(),
|
||||||
|
output_index, points_to_analysis,
|
||||||
|
&colocated_set);
|
||||||
|
AddBufferToColocatedSet(
|
||||||
|
module->entry_computation()->parameter_instruction(param_number),
|
||||||
|
param_index, points_to_analysis, &colocated_set);
|
||||||
|
AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
|
||||||
|
});
|
||||||
|
|
||||||
for (const HloComputation* computation : module->MakeComputationPostOrder()) {
|
for (const HloComputation* computation : module->MakeComputationPostOrder()) {
|
||||||
if (computation->IsFusionComputation()) {
|
if (computation->IsFusionComputation()) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -141,6 +141,9 @@ class BufferValue {
|
|||||||
// operator< is required for std::set.
|
// operator< is required for std::set.
|
||||||
bool operator<(const BufferValue& other) const { return id_ < other.id_; }
|
bool operator<(const BufferValue& other) const { return id_ < other.id_; }
|
||||||
|
|
||||||
|
bool operator==(const BufferValue& other) const { return id_ == other.id_; }
|
||||||
|
bool operator!=(const BufferValue& other) const { return id_ != other.id_; }
|
||||||
|
|
||||||
virtual string ToString() const = 0;
|
virtual string ToString() const = 0;
|
||||||
|
|
||||||
// TODO(lauj) rename LogicalBufferProto to BufferValueProto.
|
// TODO(lauj) rename LogicalBufferProto to BufferValueProto.
|
||||||
|
@ -40,10 +40,12 @@ namespace {
|
|||||||
|
|
||||||
using absl::StrAppend;
|
using absl::StrAppend;
|
||||||
|
|
||||||
bool IsEntryParameterValue(const HloValue& value) {
|
bool IsReadonlyEntryParameterValue(const HloValue& value) {
|
||||||
const HloComputation* computation = value.defining_instruction()->parent();
|
const HloComputation* computation = value.defining_instruction()->parent();
|
||||||
return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
|
return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
|
||||||
computation == computation->parent()->entry_computation();
|
computation == computation->parent()->entry_computation() &&
|
||||||
|
!computation->parent()->input_output_alias_config().ParameterHasAlias(
|
||||||
|
value.defining_instruction()->parameter_number(), value.index());
|
||||||
}
|
}
|
||||||
|
|
||||||
bool IsConstantValue(const HloValue& value) {
|
bool IsConstantValue(const HloValue& value) {
|
||||||
@ -51,7 +53,7 @@ bool IsConstantValue(const HloValue& value) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool ValueIsReadOnly(const HloValue& value) {
|
bool ValueIsReadOnly(const HloValue& value) {
|
||||||
return IsConstantValue(value) || IsEntryParameterValue(value);
|
return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Data structure describing the action which should be taken on parts of a
|
// Data structure describing the action which should be taken on parts of a
|
||||||
@ -79,8 +81,7 @@ SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
|
|||||||
bool ShouldCopyRootValue(const HloValue& value,
|
bool ShouldCopyRootValue(const HloValue& value,
|
||||||
const SpecialCaseCopyPolicy& policy) {
|
const SpecialCaseCopyPolicy& policy) {
|
||||||
if (policy.copy_parameters_and_constants) {
|
if (policy.copy_parameters_and_constants) {
|
||||||
return IsConstantValue(value) ||
|
return ValueIsReadOnly(value);
|
||||||
value.defining_instruction()->opcode() == HloOpcode::kParameter;
|
|
||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -332,6 +333,81 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Conservatively adds copies before root instruction of entry computation and
|
||||||
|
// each aliased parameter to resolve interference of aliased input and output
|
||||||
|
// buffer. We later rely on the CopyRemover to drop the unnecessary ones.
|
||||||
|
Status AddCopiesForAliasedInputOutputs(HloModule* module) {
|
||||||
|
HloComputation* entry = module->entry_computation();
|
||||||
|
HloInstruction* root = entry->root_instruction();
|
||||||
|
|
||||||
|
ShapeTree<bool> output_indices_to_copy(root->shape());
|
||||||
|
std::vector<ShapeTree<HloInstruction*>> copied_parameters;
|
||||||
|
bool has_alias = false;
|
||||||
|
for (auto* param : entry->parameter_instructions()) {
|
||||||
|
bool param_has_alias = false;
|
||||||
|
ShapeTree<bool> param_indices_to_copy(param->shape());
|
||||||
|
|
||||||
|
module->input_output_alias_config().ForEachAlias(
|
||||||
|
[&](const ShapeIndex& output_index, int64 param_number,
|
||||||
|
const ShapeIndex& param_index) {
|
||||||
|
if (param_number == param->parameter_number()) {
|
||||||
|
param_has_alias = true;
|
||||||
|
*(param_indices_to_copy.mutable_element(param_index)) = true;
|
||||||
|
*(output_indices_to_copy.mutable_element(output_index)) = true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!param_has_alias) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
has_alias = true;
|
||||||
|
// Store a snapshot of users before DeepCopyInstruction, as
|
||||||
|
// DeepCopyInstruction introduces new users of the instruction.
|
||||||
|
std::vector<HloInstruction*> users = param->users();
|
||||||
|
ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
|
||||||
|
/*init_value=*/nullptr);
|
||||||
|
TF_ASSIGN_OR_RETURN(HloInstruction * copied,
|
||||||
|
entry->DeepCopyInstruction(
|
||||||
|
param, ¶m_indices_to_copy, ¶m_copy_tree));
|
||||||
|
for (HloInstruction* user : users) {
|
||||||
|
TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
|
||||||
|
}
|
||||||
|
|
||||||
|
copied_parameters.push_back(param_copy_tree);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!has_alias) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add copies before root instruction.
|
||||||
|
ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
|
||||||
|
/*init_value=*/nullptr);
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
|
||||||
|
root->parent()->DeepCopyInstruction(
|
||||||
|
root, &output_indices_to_copy, &output_copy_tree));
|
||||||
|
|
||||||
|
// Add control dependencies between the input/output copies.
|
||||||
|
TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
|
||||||
|
[&](const ShapeIndex& output_index, int64 param_number,
|
||||||
|
const ShapeIndex& input_index) -> Status {
|
||||||
|
HloInstruction* from =
|
||||||
|
copied_parameters[param_number].element(input_index);
|
||||||
|
HloInstruction* to = output_copy_tree.element(output_index);
|
||||||
|
|
||||||
|
TF_RET_CHECK(from != nullptr);
|
||||||
|
TF_RET_CHECK(to != nullptr);
|
||||||
|
TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
|
||||||
|
return Status::OK();
|
||||||
|
}));
|
||||||
|
|
||||||
|
entry->set_root_instruction(root_copied);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Removes any control dependencies to or from the given instruction.
|
// Removes any control dependencies to or from the given instruction.
|
||||||
Status StripControlDependenciesFrom(HloInstruction* instruction) {
|
Status StripControlDependenciesFrom(HloInstruction* instruction) {
|
||||||
while (!instruction->control_successors().empty()) {
|
while (!instruction->control_successors().empty()) {
|
||||||
@ -953,6 +1029,8 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1351,6 +1351,218 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
|
|||||||
EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
|
EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(CopyInsertionTest, CrossingParameters) {
|
||||||
|
// Test a case where two parameters' dataflow cross with each other while
|
||||||
|
// input and output are aliased with same index:
|
||||||
|
//
|
||||||
|
// (p0 , p1)
|
||||||
|
// | \ /|
|
||||||
|
// | \ / |
|
||||||
|
// alias X alias
|
||||||
|
// | / \ |
|
||||||
|
// | / \|
|
||||||
|
// (p1 , p0)
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
const Shape tuple_shape =
|
||||||
|
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
|
||||||
|
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
auto param = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, tuple_shape, "0"));
|
||||||
|
auto gte0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
|
||||||
|
auto gte1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0}));
|
||||||
|
module->AddEntryComputation(builder.Build());
|
||||||
|
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||||
|
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||||
|
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||||
|
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
|
||||||
|
InsertCopies(module.get());
|
||||||
|
|
||||||
|
EXPECT_EQ(CountCopies(*module), 4);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CopyInsertionTest, ParametersAliasing) {
|
||||||
|
// Test a case where two parameters' dataflow don't interfere with each other
|
||||||
|
// while aliased.
|
||||||
|
//
|
||||||
|
// (p0 , p1)
|
||||||
|
// | |
|
||||||
|
// | |
|
||||||
|
// alias alias
|
||||||
|
// | |
|
||||||
|
// | |
|
||||||
|
// (p0 , p1)
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
const Shape tuple_shape =
|
||||||
|
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
|
||||||
|
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
auto param = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
|
||||||
|
auto gte0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
|
||||||
|
auto gte1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
|
||||||
|
module->AddEntryComputation(builder.Build());
|
||||||
|
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||||
|
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||||
|
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||||
|
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
|
||||||
|
InsertCopies(module.get());
|
||||||
|
|
||||||
|
EXPECT_EQ(CountCopies(*module), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CopyInsertionTest, ParameterWithNoAliasing) {
|
||||||
|
// Test a case where no parameter is aliased with result. In this case, copy
|
||||||
|
// should be added
|
||||||
|
//
|
||||||
|
// (p0 , p1)
|
||||||
|
// | |
|
||||||
|
// | |
|
||||||
|
// | |
|
||||||
|
// | |
|
||||||
|
// | |
|
||||||
|
// (p0 , p1)
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
const Shape tuple_shape =
|
||||||
|
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
|
||||||
|
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
auto param = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
|
||||||
|
auto gte0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
|
||||||
|
auto gte1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
|
||||||
|
module->AddEntryComputation(builder.Build());
|
||||||
|
InsertCopies(module.get());
|
||||||
|
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
op::Tuple(op::Copy(op::GetTupleElement(param, 0)),
|
||||||
|
op::Copy(op::GetTupleElement(param, 1))));
|
||||||
|
|
||||||
|
EXPECT_EQ(CountCopies(*module), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
|
||||||
|
// Test a case where one parameter is aliased with result while another one
|
||||||
|
// isn't.
|
||||||
|
//
|
||||||
|
// (p0 , p1)
|
||||||
|
// | |
|
||||||
|
// | |
|
||||||
|
// alias |
|
||||||
|
// | |
|
||||||
|
// | |
|
||||||
|
// (p0 , p1)
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
const Shape tuple_shape =
|
||||||
|
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
|
||||||
|
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
auto param = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
|
||||||
|
auto gte0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
|
||||||
|
auto gte1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
|
||||||
|
module->AddEntryComputation(builder.Build());
|
||||||
|
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||||
|
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||||
|
InsertCopies(module.get());
|
||||||
|
|
||||||
|
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||||
|
op::Tuple(op::GetTupleElement(param, 0),
|
||||||
|
op::Copy(op::GetTupleElement(param, 1))));
|
||||||
|
|
||||||
|
EXPECT_EQ(CountCopies(*module), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
|
||||||
|
// Test a case where one parameter is aliased with result while another one
|
||||||
|
// isn't.
|
||||||
|
//
|
||||||
|
// +-- (p0 , p1)
|
||||||
|
// | | |
|
||||||
|
// | | |
|
||||||
|
// alias Negate Negate
|
||||||
|
// | | |
|
||||||
|
// | | |
|
||||||
|
// +-- (p0 , p1)
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
const Shape tuple_shape =
|
||||||
|
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
|
||||||
|
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
auto param = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
|
||||||
|
auto gte0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
|
||||||
|
auto gte1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
|
||||||
|
|
||||||
|
auto negate0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
|
||||||
|
|
||||||
|
auto negate1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
|
||||||
|
module->AddEntryComputation(builder.Build());
|
||||||
|
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||||
|
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||||
|
InsertCopies(module.get());
|
||||||
|
|
||||||
|
EXPECT_EQ(CountCopies(*module), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
|
||||||
|
// Test a case where one parameter is aliased with result while another one
|
||||||
|
// isn't.
|
||||||
|
//
|
||||||
|
// +-- (p0 , p1)
|
||||||
|
// | | |
|
||||||
|
// | | |
|
||||||
|
// alias Negate Negate
|
||||||
|
// | | |
|
||||||
|
// | Add----+
|
||||||
|
// | | |
|
||||||
|
// +-- (p0 , p1)
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
const Shape tuple_shape =
|
||||||
|
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
|
||||||
|
|
||||||
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
auto param = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
|
||||||
|
auto gte0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
|
||||||
|
auto gte1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
|
||||||
|
|
||||||
|
auto negate0 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
|
||||||
|
|
||||||
|
auto negate1 = builder.AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
|
||||||
|
|
||||||
|
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
|
||||||
|
scalar_shape_, HloOpcode::kAdd, negate0, negate1));
|
||||||
|
builder.AddInstruction(HloInstruction::CreateTuple({add, negate1}));
|
||||||
|
module->AddEntryComputation(builder.Build());
|
||||||
|
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||||
|
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||||
|
InsertCopies(module.get());
|
||||||
|
|
||||||
|
EXPECT_EQ(CountCopies(*module), 0);
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
|
TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
|
||||||
// Test a while instruction with a body which permutes its tuple parameter
|
// Test a while instruction with a body which permutes its tuple parameter
|
||||||
// elements and applies one operation to one of the elements. The addition of
|
// elements and applies one operation to one of the elements. The addition of
|
||||||
|
@ -225,6 +225,32 @@ message HloScheduleProto {
|
|||||||
map<int64, InstructionSequence> sequences = 1;
|
map<int64, InstructionSequence> sequences = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message HloInputOutputAliasProto {
|
||||||
|
// The following proto describes an aliased pair of an input
|
||||||
|
// (described by parameter number and a ShapeIndex of the parameter)
|
||||||
|
// and an output (described by a ShapeIndex of the root
|
||||||
|
// instruction). For example:
|
||||||
|
//
|
||||||
|
// entry = {
|
||||||
|
// output_shape_index={1},
|
||||||
|
// parameter_number=0,
|
||||||
|
// parameter_shape_index={1, 2},
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// This entry indicates that the first parameter's {1, 2} element is
|
||||||
|
// aliased with the {1} element of the root instruction.
|
||||||
|
message AliasEntryProto {
|
||||||
|
// ShapeIndex of the root hlo.
|
||||||
|
repeated int64 output_shape_index = 1;
|
||||||
|
// Number of the parameter in entry computation.
|
||||||
|
int64 parameter_number = 2;
|
||||||
|
// ShapeIndex of the parameter instruction.
|
||||||
|
repeated int64 parameter_shape_index = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
repeated AliasEntryProto entries = 1;
|
||||||
|
}
|
||||||
|
|
||||||
// Serialization of HloModule.
|
// Serialization of HloModule.
|
||||||
message HloModuleProto {
|
message HloModuleProto {
|
||||||
string name = 1;
|
string name = 1;
|
||||||
@ -243,6 +269,9 @@ message HloModuleProto {
|
|||||||
|
|
||||||
// The schedule for this module.
|
// The schedule for this module.
|
||||||
HloScheduleProto schedule = 7;
|
HloScheduleProto schedule = 7;
|
||||||
|
|
||||||
|
// Describes alias information between inputs and outputs.
|
||||||
|
HloInputOutputAliasProto input_output_alias = 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialization of LogicalBuffer.
|
// Serialization of LogicalBuffer.
|
||||||
|
@ -59,8 +59,9 @@ class BufferValueMap {
|
|||||||
// construction process.
|
// construction process.
|
||||||
using BufferNumber = int64;
|
using BufferNumber = int64;
|
||||||
|
|
||||||
explicit BufferValueMap(const HloDataflowAnalysis& dataflow)
|
explicit BufferValueMap(HloModule* module,
|
||||||
: dataflow_(dataflow) {
|
const HloDataflowAnalysis& dataflow)
|
||||||
|
: module_(module), dataflow_(dataflow) {
|
||||||
buffers_.reserve(dataflow_.values().size());
|
buffers_.reserve(dataflow_.values().size());
|
||||||
value_to_buffer_number_.reserve(dataflow_.values().size());
|
value_to_buffer_number_.reserve(dataflow_.values().size());
|
||||||
for (const HloValue* value : dataflow_.values()) {
|
for (const HloValue* value : dataflow_.values()) {
|
||||||
@ -171,6 +172,42 @@ class BufferValueMap {
|
|||||||
return value_to_buffer_number_.at(&value);
|
return value_to_buffer_number_.at(&value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ComputeInputOutputAliasedBuffers(
|
||||||
|
const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
|
||||||
|
// Get parameter value from an aliased_input object.
|
||||||
|
const auto get_parameter_value =
|
||||||
|
[this](const std::pair<int64, ShapeIndex>& aliased_input)
|
||||||
|
-> const HloValue& {
|
||||||
|
int64 param_number = aliased_input.first;
|
||||||
|
const ShapeIndex& param_index = aliased_input.second;
|
||||||
|
return dataflow_.GetUniqueValueAt(
|
||||||
|
module_->entry_computation()->parameter_instruction(param_number),
|
||||||
|
param_index);
|
||||||
|
};
|
||||||
|
|
||||||
|
// If the value shows up in a root instruction, alias it with parameter
|
||||||
|
// instruction.
|
||||||
|
for (const HloPosition& pos : value.positions()) {
|
||||||
|
if (pos.instruction == module_->entry_computation()->root_instruction()) {
|
||||||
|
ShapeIndex output_index = pos.index;
|
||||||
|
|
||||||
|
auto aliased_input =
|
||||||
|
module_->input_output_alias_config().GetAliasedParameter(
|
||||||
|
output_index);
|
||||||
|
if (aliased_input) {
|
||||||
|
aliased_buffers->push_back(
|
||||||
|
GetBufferForValue(get_parameter_value(*aliased_input)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the value is parameter instruction itself, alias it with itself.
|
||||||
|
if (value.instruction()->opcode() == HloOpcode::kParameter &&
|
||||||
|
value.instruction()->parent() == module_->entry_computation()) {
|
||||||
|
aliased_buffers->push_back(GetBufferForValue(value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ComputeWhileAliasedBuffers(const HloValue& value,
|
void ComputeWhileAliasedBuffers(const HloValue& value,
|
||||||
std::vector<BufferNumber>* aliased_buffers) {
|
std::vector<BufferNumber>* aliased_buffers) {
|
||||||
VLOG(3) << "Compute kWhile aliases";
|
VLOG(3) << "Compute kWhile aliases";
|
||||||
@ -278,6 +315,7 @@ class BufferValueMap {
|
|||||||
VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
|
VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
|
||||||
}
|
}
|
||||||
std::vector<BufferNumber> aliased_buffers;
|
std::vector<BufferNumber> aliased_buffers;
|
||||||
|
ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
|
||||||
ComputeWhileAliasedBuffers(value, &aliased_buffers);
|
ComputeWhileAliasedBuffers(value, &aliased_buffers);
|
||||||
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
|
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
|
||||||
// Uniquify aliased buffers.
|
// Uniquify aliased buffers.
|
||||||
@ -288,6 +326,8 @@ class BufferValueMap {
|
|||||||
return aliased_buffers;
|
return aliased_buffers;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloModule* module_;
|
||||||
|
|
||||||
// Dataflow analysis used to construct the buffer map.
|
// Dataflow analysis used to construct the buffer map.
|
||||||
const HloDataflowAnalysis& dataflow_;
|
const HloDataflowAnalysis& dataflow_;
|
||||||
|
|
||||||
@ -461,7 +501,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
|
|||||||
/*bitcast_defines_value=*/false,
|
/*bitcast_defines_value=*/false,
|
||||||
fusion_can_share_buffer));
|
fusion_can_share_buffer));
|
||||||
|
|
||||||
BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
|
BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
|
||||||
buffer_map.MergeAliasedBuffers();
|
buffer_map.MergeAliasedBuffers();
|
||||||
|
|
||||||
// Create a vector of HloBuffers, one for each set of values in the
|
// Create a vector of HloBuffers, one for each set of values in the
|
||||||
|
@ -217,6 +217,181 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
|
|||||||
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
|
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that when output tuple element i is configured to alias parameter
// tuple element i, alias analysis reports the same buffer for the parameter
// element and the matching output element.
TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) {
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // entry: Tuple(Negate(p0{0}), Negate(p0{1}))
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
  module_->AddEntryComputation(builder.Build());

  // Output {i} <-> parameter 0 at {i}.
  HloInputOutputAliasConfig& alias_config =
      module_->input_output_alias_config();
  TF_ASSERT_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  TF_ASSERT_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));

  // Cannot alias an output twice.
  ASSERT_IS_NOT_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));

  const HloAliasAnalysis& analysis = RunAnalysis();

  EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));

  EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
}
|
||||||
|
|
||||||
|
// Checks the "crossed" configuration: output 0 aliases parameter element 1 and
// output 1 aliases parameter element 0, while the graph itself forwards each
// parameter element straight to the output at the same index.
//
//   (p0 ,  p1)
//      \   /
//       \ /
// alias  X
//       / \
//      /   \
//   (p0 ,  p1)
TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) {
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // entry: Tuple(p0{0}, p0{1})
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
  HloInstruction* gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
  module_->AddEntryComputation(builder.Build());

  HloInputOutputAliasConfig& alias_config =
      module_->input_output_alias_config();
  TF_ASSERT_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}));
  TF_ASSERT_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));

  // Cannot alias an output twice.
  ASSERT_IS_NOT_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));

  const HloAliasAnalysis& analysis = RunAnalysis();

  // Both parameter elements and both output elements are expected to end up
  // in one and the same buffer.
  EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));

  EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
  EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
            analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
}
|
||||||
|
|
||||||
|
// Checks input/output aliasing across a single while instruction that sits
// between the entry parameter and the entry root.
//
// body((F32[], F32[]) %tuple_param):
//   %add = Add(%tuple_param{0}, %tuple_param{1})
//   return Tuple(%tuple_param{0}, %add)
//
// condition((F32[], F32[]) %tuple_param):
//   return Constant(false)
//
// entry:
//   %p0 = parameter(0)
//   %while = While(%p0, body, condition)
//   %while_1 = GTE(%while, 0)
//   %while_2 = GTE(%while, 1)
//   %negate_1 = Negate(%while_1)
//   %negate_2 = Negate(%while_2)
//   return Tuple(negate_1, negate_2)
//
TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) {
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});

  // Loop body: element 0 is forwarded untouched, element 1 becomes the sum.
  HloComputation::Builder body_builder("body");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  HloInstruction* body_element_0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
  HloInstruction* body_element_1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
  HloInstruction* add = body_builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
  HloInstruction* body_tuple = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_element_0, add}));
  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // Loop condition: a constant "false".
  HloComputation::Builder cond_builder("condition");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Entry computation.
  HloComputation::Builder builder(TestName());
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* xla_while = builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, body, param));
  HloInstruction* while_element_1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0));
  HloInstruction* while_element_2 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1));
  HloInstruction* negate_1 = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kNegate, while_element_1));
  HloInstruction* negate_2 = builder.AddInstruction(HloInstruction::CreateUnary(
      scalar_shape_, HloOpcode::kNegate, while_element_2));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2}));
  module_->AddEntryComputation(builder.Build());

  // Output {i} <-> parameter 0 at {i}.
  HloInputOutputAliasConfig& alias_config =
      module_->input_output_alias_config();
  TF_ASSERT_OK(alias_config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
  TF_ASSERT_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));

  const HloAliasAnalysis& analysis = RunAnalysis();

  // Values sharing the buffer of while{1}.
  EXPECT_THAT(
      GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
      UnorderedElementsAre(GetValueDefinedAt(param, {1}),
                           GetValueDefinedAt(xla_while, /*index=*/{1}),
                           GetValueDefinedAt(body_param, {1}),
                           GetValueDefinedAt(cond_param, {1}),
                           GetValueDefinedAt(add),
                           GetValueDefinedAt(negate_2)));

  // Positions covered by the buffer of while{1}.
  EXPECT_THAT(
      analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(),
      UnorderedElementsAre(
          HloPosition{param, {1}}, HloPosition{xla_while, {1}},
          HloPosition{while_element_2, {}}, HloPosition{body_param, {1}},
          HloPosition{body_element_1, {}}, HloPosition{add, {}},
          HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}},
          HloPosition{cond_param, {1}}, HloPosition{negate_2, {}}));

  EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
|
||||||
|
|
||||||
TEST_F(HloAliasAnalysisTest, SingleCall) {
|
TEST_F(HloAliasAnalysisTest, SingleCall) {
|
||||||
// Test a single call of a subcomputation. The subcomputation adds its two
|
// Test a single call of a subcomputation. The subcomputation adds its two
|
||||||
// array-shaped parameters.
|
// array-shaped parameters.
|
||||||
|
@ -126,7 +126,7 @@ bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
|
|||||||
|
|
||||||
const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
|
const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
|
||||||
const HloInstruction* instruction, const ShapeIndex& index) const {
|
const HloInstruction* instruction, const ShapeIndex& index) const {
|
||||||
CHECK(ValueIsDefinedAt(instruction, index));
|
CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
|
||||||
return GetUniqueValueAt(instruction, index);
|
return GetUniqueValueAt(instruction, index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
182
tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
Normal file
182
tensorflow/compiler/xla/service/hlo_input_output_alias_config.cc
Normal file
@ -0,0 +1,182 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
// Records that the output buffer at `output_index` aliases the buffer at
// `param_index` of entry parameter `param_number`.
//
// Returns an error status (and leaves the config unchanged) when
//   * `output_index` is not a valid index into the output shape,
//   * `param_number` is negative, or
//   * `output_index` already has an alias registered.
// Whether `param_number`/`param_index` exist in a given module is not checked
// here; that is done by Verify().
Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
                                             int64 param_number,
                                             const ShapeIndex& param_index) {
  TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
      << "Trying to set up alias at " << output_index.ToString()
      << " which is an invalid index for shape "
      << ShapeUtil::HumanString(alias_.shape());
  // A negative parameter number can never name an entry parameter; reject it
  // up front so later consumers never index with it.
  TF_RET_CHECK(param_number >= 0)
      << "Trying to set up alias with invalid parameter number "
      << param_number;
  // Output can't be aliased with multiple parameters. The message arguments
  // are only evaluated when the check fails, i.e. when the element is set.
  TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat(
      "Trying to set up output alias for param %lld at %s but failed: output "
      "index %s is already aliased with param %lld at %s",
      param_number, param_index.ToString(), output_index.ToString(),
      alias_.element(output_index)->first,
      alias_.element(output_index)->second.ToString());
  (*alias_.mutable_element(output_index)) =
      std::make_pair(param_number, param_index);
  return Status::OK();
}
|
||||||
|
|
||||||
|
// Serializes the config: one AliasEntryProto is appended for every output
// index that currently has an alias; output indices without one are skipped.
HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
  HloInputOutputAliasProto proto;
  alias_.ForEachElement(
      [&proto](const ShapeIndex& output_index,
               const absl::optional<std::pair<int64, ShapeIndex>>& aliased) {
        if (!aliased) {
          return;
        }
        // Fill the new entry in place instead of building a temporary.
        HloInputOutputAliasProto::AliasEntryProto* entry = proto.add_entries();
        for (int64 dim : output_index) {
          entry->add_output_shape_index(dim);
        }
        entry->set_parameter_number(aliased->first);
        for (int64 dim : aliased->second) {
          entry->add_parameter_shape_index(dim);
        }
      });
  return proto;
}
|
||||||
|
|
||||||
|
// Rebuilds a config for `output_shape` from `proto` by replaying every entry
// through SetUpAlias(). The first entry that SetUpAlias() rejects aborts the
// conversion and its error status is returned.
StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
    const Shape& output_shape, const HloInputOutputAliasProto& proto) {
  HloInputOutputAliasConfig config(output_shape);
  for (const HloInputOutputAliasProto::AliasEntryProto& entry :
       proto.entries()) {
    const auto& output_dims = entry.output_shape_index();
    const auto& param_dims = entry.parameter_shape_index();
    TF_RETURN_IF_ERROR(config.SetUpAlias(
        ShapeIndex(output_dims.begin(), output_dims.end()),
        entry.parameter_number(),
        ShapeIndex(param_dims.begin(), param_dims.end())));
  }
  return config;
}
|
||||||
|
|
||||||
|
// Renders the config as a header line followed by one line per alias, joined
// with newlines.
string HloInputOutputAliasConfig::ToString() const {
  std::vector<string> lines = {"HloInputOutputAliasConfig"};

  const auto append_alias = [&lines](const ShapeIndex& output_index,
                                     int64 param_number,
                                     const ShapeIndex& param_index) {
    lines.push_back(absl::StrFormat(
        " OutputIndex %s is aliased with parameter %lld at %s:",
        output_index.ToString(), param_number, param_index.ToString()));
  };
  ForEachAlias(append_alias);

  return absl::StrJoin(lines, "\n");
}
|
||||||
|
|
||||||
|
// Returns true iff some output index is aliased with the buffer at
// `param_index` of parameter `param_number`.
//
// Implemented on top of GetAliasedOutput() so the lookup logic lives in one
// place; a parameter has an alias exactly when that lookup finds an output.
bool HloInputOutputAliasConfig::ParameterHasAlias(
    int64 param_number, const ShapeIndex& param_index) const {
  return GetAliasedOutput(param_number, param_index).has_value();
}
|
||||||
|
|
||||||
|
// Returns the output index aliased with the buffer at `param_index` of
// parameter `param_number`, or nullopt when no output aliases it.
//
// The whole alias tree is scanned; if several outputs alias the same
// parameter buffer (a state Verify() rejects), the last one visited wins.
absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
    int64 param_number, const ShapeIndex& param_index) const {
  absl::optional<ShapeIndex> output;
  alias_.ForEachElement(
      // The element is taken by const reference so that visiting a node does
      // not copy the stored (param_number, ShapeIndex) pair.
      [&](const xla::ShapeIndex& output_index,
          const absl::optional<std::pair<int64, ShapeIndex>>& alias) {
        if (alias && alias->first == param_number &&
            alias->second == param_index) {
          output = output_index;
        }
      });
  return output;
}
|
||||||
|
|
||||||
|
// Looks up the (parameter number, parameter index) pair registered for
// `output_index`; the stored value is nullopt when that output has no alias.
// CHECK-fails if `output_index` is not a valid index into the output shape.
absl::optional<std::pair<int64, ShapeIndex>>
HloInputOutputAliasConfig::GetAliasedParameter(
    const ShapeIndex& output_index) const {
  CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
  const absl::optional<std::pair<int64, ShapeIndex>>& aliased_param =
      alias_.element(output_index);
  return aliased_param;
}
|
||||||
|
|
||||||
|
// Invokes `fn(output_index, param_number, param_index)` for every output index
// that has an alias; output indices without an alias are skipped.
void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const {
  alias_.ForEachElement(
      // Bind the element by const reference: taking it by value would copy the
      // optional (and the ShapeIndex inside it) for every node of the tree.
      [&](const ShapeIndex& output_index,
          const absl::optional<std::pair<int64, ShapeIndex>>& aliased) {
        if (aliased) {
          fn(output_index, aliased->first, aliased->second);
        }
      });
}
|
||||||
|
|
||||||
|
// Status-returning variant of ForEachAlias(): invokes `fn` for every output
// index that has an alias and propagates the first non-OK status `fn` returns.
Status HloInputOutputAliasConfig::ForEachAliasWithStatus(
    AliasFnWithStatus fn) const {
  return alias_.ForEachElementWithStatus(
      // Bind the element by const reference: taking it by value would copy the
      // optional (and the ShapeIndex inside it) for every node of the tree.
      [&](const ShapeIndex& output_index,
          const absl::optional<std::pair<int64, ShapeIndex>>& aliased)
          -> Status {
        if (aliased) {
          TF_RETURN_IF_ERROR(fn(output_index, aliased->first, aliased->second));
        }
        return Status::OK();
      });
}
|
||||||
|
|
||||||
|
// Checks this config against the entry computation of `module`.
//
// For every registered alias the following must hold, otherwise an error
// status is returned:
//   * the parameter number names an existing entry parameter,
//   * the parameter index is valid for that parameter's shape,
//   * the output index is valid for the entry root's shape,
//   * the (parameter number, parameter index) pair is used by one alias only.
Status HloInputOutputAliasConfig::Verify(const HloModule& module) const {
  const HloComputation* entry = module.entry_computation();

  // One "already aliased" flag per parameter buffer, indexed by parameter
  // number and then by shape index.
  std::vector<ShapeTree<bool>> param_has_seen;
  param_has_seen.reserve(entry->num_parameters());
  for (int64 i = 0; i < entry->num_parameters(); ++i) {
    HloInstruction* param = entry->parameter_instruction(i);
    param_has_seen.emplace_back(param->shape());
  }

  return ForEachAliasWithStatus([&](const ShapeIndex& output_index,
                                    int64 param_number,
                                    const ShapeIndex& param_index) -> Status {
    // The parameter number must be range-checked BEFORE it is used to fetch
    // the parameter instruction or to index `param_has_seen`.
    TF_RET_CHECK(param_number >= 0);
    TF_RET_CHECK(entry->num_parameters() > param_number);

    const HloInstruction* root = entry->root_instruction();
    const Shape& param_shape =
        entry->parameter_instruction(param_number)->shape();
    const Shape& output_shape = root->shape();
    TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index));
    TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index));

    // Check each param_number and param_index pair only shows up once: no
    // input buffer may be aliased with more than one output buffer.
    TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false);

    *(param_has_seen[param_number].mutable_element(param_index)) = true;

    return Status::OK();
  });
}
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& out,
|
||||||
|
const HloInputOutputAliasConfig& config) {
|
||||||
|
out << config.ToString();
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
} // namespace xla
|
102
tensorflow/compiler/xla/service/hlo_input_output_alias_config.h
Normal file
102
tensorflow/compiler/xla/service/hlo_input_output_alias_config.h
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
|
||||||
|
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
|
||||||
|
|
||||||
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
class HloModule;
|
||||||
|
|
||||||
|
// This class specifies the alias map from output index to parameter number and
|
||||||
|
// parameter index in the entry computation.
|
||||||
|
class HloInputOutputAliasConfig {
|
||||||
|
public:
|
||||||
|
HloInputOutputAliasConfig() = default;
|
||||||
|
|
||||||
|
explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {}
|
||||||
|
|
||||||
|
virtual ~HloInputOutputAliasConfig() = default;
|
||||||
|
|
||||||
|
// Sets up alias config from `output_index` to `param_index` at
|
||||||
|
// `param_number`.
|
||||||
|
Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
|
||||||
|
const ShapeIndex& param_index);
|
||||||
|
|
||||||
|
// Returns true if the given parameter is aliased with one of the output
|
||||||
|
// buffers.
|
||||||
|
bool ParameterHasAlias(int64 param_number,
|
||||||
|
const ShapeIndex& param_index) const;
|
||||||
|
|
||||||
|
// (De)Serializes an HloInputOutoutAliasConfig to/from an
|
||||||
|
// HloInputOutoutAliasProto.
|
||||||
|
HloInputOutputAliasProto ToProto() const;
|
||||||
|
|
||||||
|
static StatusOr<HloInputOutputAliasConfig> CreateFromProto(
|
||||||
|
const Shape& output_shape, const HloInputOutputAliasProto& proto);
|
||||||
|
|
||||||
|
// Returns the output index that the given parameter and parameter index is
|
||||||
|
// aliased with. A nullopt is returned if there is no output that is aliased
|
||||||
|
// with the parameter number and index.
|
||||||
|
absl::optional<ShapeIndex> GetAliasedOutput(
|
||||||
|
int64 param_number, const ShapeIndex& param_index) const;
|
||||||
|
|
||||||
|
// Returns the number of parameter and index of the parameter buffer that the
|
||||||
|
// given output buffer index is aliased with. A nullopt is returned if there
|
||||||
|
// is no parameter is aliased with the specific output.
|
||||||
|
absl::optional<std::pair<int64, ShapeIndex>> GetAliasedParameter(
|
||||||
|
const ShapeIndex& output_index) const;
|
||||||
|
|
||||||
|
using AliasFn =
|
||||||
|
std::function<void(const ShapeIndex& output_index, int64 param_number,
|
||||||
|
const ShapeIndex& param_index)>;
|
||||||
|
|
||||||
|
// Iterates through each aliased output and input.
|
||||||
|
void ForEachAlias(AliasFn fn) const;
|
||||||
|
|
||||||
|
using AliasFnWithStatus =
|
||||||
|
std::function<Status(const ShapeIndex& output_index, int64 param_number,
|
||||||
|
const ShapeIndex& param_index)>;
|
||||||
|
|
||||||
|
// Verifies that the given config is valid for the given module.
|
||||||
|
// Specifically, the config's input and output should be in-bound and size of
|
||||||
|
// the aliased buffers should match.
|
||||||
|
Status Verify(const HloModule& module) const;
|
||||||
|
|
||||||
|
Status ForEachAliasWithStatus(AliasFnWithStatus fn) const;
|
||||||
|
|
||||||
|
string ToString() const;
|
||||||
|
|
||||||
|
private:
|
||||||
|
// A ShapeTree which indicates the list of buffers that's expected to be
|
||||||
|
// aliased. The key on this shape tree represents the output index. The value
|
||||||
|
// is a pair of parameter number and index into the buffer. If the value is
|
||||||
|
// nullopt, it means there is no parameter aliasing for this output.
|
||||||
|
ShapeTree<absl::optional<std::pair<int64, ShapeIndex>>> alias_;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::ostream& operator<<(std::ostream& out,
|
||||||
|
const HloInputOutputAliasConfig& config);
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
|
@ -0,0 +1,184 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
#include "absl/algorithm/container.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||||
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace {
|
||||||
|
class HloInputOutputAliasConfigTest : public HloTestBase {
|
||||||
|
protected:
|
||||||
|
void expect_aliased(const ShapeIndex& output_index, int64 param_number,
|
||||||
|
const ShapeIndex& param_index,
|
||||||
|
const HloInputOutputAliasConfig& config) {
|
||||||
|
absl::optional<ShapeIndex> aliased_output =
|
||||||
|
config.GetAliasedOutput(param_number, param_index);
|
||||||
|
|
||||||
|
EXPECT_TRUE(aliased_output);
|
||||||
|
EXPECT_EQ(aliased_output.value(), output_index);
|
||||||
|
|
||||||
|
absl::optional<std::pair<int64, ShapeIndex>> aliased_param =
|
||||||
|
config.GetAliasedParameter(output_index);
|
||||||
|
|
||||||
|
EXPECT_TRUE(aliased_param);
|
||||||
|
EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index));
|
||||||
|
}
|
||||||
|
|
||||||
|
void expect_not_aliased(const ShapeIndex& output_index, int64 param_number,
|
||||||
|
const ShapeIndex& param_index,
|
||||||
|
const HloInputOutputAliasConfig& config) {
|
||||||
|
absl::optional<ShapeIndex> aliased_output =
|
||||||
|
config.GetAliasedOutput(param_number, param_index);
|
||||||
|
|
||||||
|
EXPECT_FALSE(aliased_output && aliased_output == output_index);
|
||||||
|
|
||||||
|
absl::optional<std::pair<int64, ShapeIndex>> aliased_param =
|
||||||
|
config.GetAliasedParameter(output_index);
|
||||||
|
|
||||||
|
EXPECT_FALSE(aliased_param && aliased_param->first == param_number &&
|
||||||
|
aliased_param->second == param_index);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// A single alias (output {0} <-> parameter 1) is visible from both lookup
// directions, and unrelated output/parameter pairs are reported as not
// aliased.
TEST_F(HloInputOutputAliasConfigTest, SimpleAliasing) {
  const string module_str = R"(
HloModule TEST

ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  const Shape& root_shape =
      module->entry_computation()->root_instruction()->shape();
  HloInputOutputAliasConfig config(root_shape);

  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
                                 /*param_index=*/{}));

  // The alias that was just registered.
  expect_aliased(/*output_index=*/{0}, /*param_number=*/1,
                 /*param_index=*/{}, config);

  // Same parameter, other output.
  expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
                     /*param_index=*/{}, config);

  // Same output, other parameter.
  expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
                     /*param_index=*/{}, config);
}
|
||||||
|
|
||||||
|
// Aliases each element of a tuple-shaped parameter with the output element at
// the same index and checks both the registered and some unregistered pairs.
TEST_F(HloInputOutputAliasConfigTest, SimpleAliasingWithTupleInput) {
  const string module_str = R"(
HloModule TEST

ENTRY main {
param = (f32[], f32[]) parameter(0)
gte1 = f32[] get-tuple-element(%param), index=0
gte2 = f32[] get-tuple-element(%param), index=1
ROOT root = (f32[], f32[]) tuple(%gte1, %gte2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  const Shape& root_shape =
      module->entry_computation()->root_instruction()->shape();
  HloInputOutputAliasConfig config(root_shape);

  // Output {i} <-> parameter 0 at {i}.
  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                                 /*param_index=*/{0}));
  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
                                 /*param_index=*/{1}));

  expect_aliased(/*output_index=*/{0}, /*param_number=*/0,
                 /*param_index=*/{0}, config);
  expect_aliased(/*output_index=*/{1}, /*param_number=*/0,
                 /*param_index=*/{1}, config);

  // Pairs that were never registered.
  expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
                     /*param_index=*/{}, config);
  expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
                     /*param_index=*/{}, config);
}
|
||||||
|
|
||||||
|
// Registering the same parameter buffer for two different outputs is accepted
// by SetUpAlias() but must be rejected by Verify().
TEST_F(HloInputOutputAliasConfigTest, InputDoNotAliasTwice) {
  const string module_str = R"(
HloModule TEST

ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  const Shape& root_shape =
      module->entry_computation()->root_instruction()->shape();
  HloInputOutputAliasConfig config(root_shape);

  // Parameter 0 is used for output {0} ...
  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                                 /*param_index=*/{}));
  // ... and again for output {1}.
  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
                                 /*param_index=*/{}));

  ASSERT_IS_NOT_OK(config.Verify(*module));
}
|
||||||
|
|
||||||
|
// An output index that already has an alias cannot be given a second one:
// the second SetUpAlias() call for the same output must fail.
TEST_F(HloInputOutputAliasConfigTest, OutputDoNotAliasTwice) {
  const string module_str = R"(
HloModule TEST

ENTRY main {
a = f32[] parameter(0)
b = f32[] parameter(1)
ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  const Shape& root_shape =
      module->entry_computation()->root_instruction()->shape();
  HloInputOutputAliasConfig config(root_shape);

  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                                 /*param_index=*/{}));

  // Output {0} is taken; a different parameter cannot claim it.
  ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
                                     /*param_index=*/{}));
}
|
||||||
|
} // namespace
|
||||||
|
} // namespace xla
|
@ -73,6 +73,8 @@ HloComputation* HloModule::AddComputationInternal(
|
|||||||
config_.SetDefaultComputationLayout(
|
config_.SetDefaultComputationLayout(
|
||||||
entry_computation_->ComputeProgramShape());
|
entry_computation_->ComputeProgramShape());
|
||||||
}
|
}
|
||||||
|
input_output_alias_config_ = HloInputOutputAliasConfig(
|
||||||
|
entry_computation_->root_instruction()->shape());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (uniquify_identifiers) {
|
if (uniquify_identifiers) {
|
||||||
@ -252,6 +254,9 @@ HloModuleProto HloModule::ToProto() const {
|
|||||||
if (has_schedule()) {
|
if (has_schedule()) {
|
||||||
*proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
|
*proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
*proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
|
||||||
|
|
||||||
return proto;
|
return proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -328,6 +333,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
|||||||
}
|
}
|
||||||
TF_RET_CHECK(module->entry_computation_ != nullptr);
|
TF_RET_CHECK(module->entry_computation_ != nullptr);
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(module->input_output_alias_config_,
|
||||||
|
HloInputOutputAliasConfig::CreateFromProto(
|
||||||
|
result_shape, proto.input_output_alias()));
|
||||||
|
|
||||||
// Because we didn't uniquify the names or the ids, double-check that the
|
// Because we didn't uniquify the names or the ids, double-check that the
|
||||||
// instruction and computation names and ids are unique from the proto.
|
// instruction and computation names and ids are unique from the proto.
|
||||||
absl::flat_hash_set<string> computation_names;
|
absl::flat_hash_set<string> computation_names;
|
||||||
|
@ -31,6 +31,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
|
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
|
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
|
||||||
@ -222,6 +223,15 @@ class HloModule {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// input_output_alias_config indicates the list of aliased buffers that are
|
||||||
|
// expected from the module.
|
||||||
|
HloInputOutputAliasConfig& input_output_alias_config() {
|
||||||
|
return input_output_alias_config_;
|
||||||
|
}
|
||||||
|
const HloInputOutputAliasConfig& input_output_alias_config() const {
|
||||||
|
return input_output_alias_config_;
|
||||||
|
}
|
||||||
|
|
||||||
// Returns an id that is unique to this module across all modules created over
|
// Returns an id that is unique to this module across all modules created over
|
||||||
// the lifetime of this process.
|
// the lifetime of this process.
|
||||||
int unique_id() const { return unique_id_; }
|
int unique_id() const { return unique_id_; }
|
||||||
@ -290,6 +300,10 @@ class HloModule {
|
|||||||
// sequential order of instructions for each non-fusion computation in the
|
// sequential order of instructions for each non-fusion computation in the
|
||||||
// module.
|
// module.
|
||||||
absl::optional<HloSchedule> schedule_;
|
absl::optional<HloSchedule> schedule_;
|
||||||
|
|
||||||
|
// alias_config indicates the alias information of input/output buffers that
|
||||||
|
// are expected from the module.
|
||||||
|
HloInputOutputAliasConfig input_output_alias_config_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -1316,6 +1316,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
|
|||||||
TF_RETURN_IF_ERROR(module->schedule().Verify());
|
TF_RETURN_IF_ERROR(module->schedule().Verify());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(*module));
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -72,7 +72,7 @@ class ShapeIndex {
|
|||||||
void push_back(int64 value) { indices_.push_back(value); }
|
void push_back(int64 value) { indices_.push_back(value); }
|
||||||
void pop_back() { indices_.pop_back(); }
|
void pop_back() { indices_.pop_back(); }
|
||||||
|
|
||||||
// push_front is O(n^2), but shapes don't usually have a ton of dimensions.
|
// push_front is O(n), but shapes don't usually have a ton of dimensions.
|
||||||
void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
|
void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
|
||||||
|
|
||||||
using container_type = absl::InlinedVector<int64, 2>;
|
using container_type = absl::InlinedVector<int64, 2>;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user