[Resubmit][XLA] Introduce input/output alias config.

- This CL introduces an input/output alias config in the HLO module that allows any HLO pass to configure it. Once the alias_config is set, each backend needs to follow the contract during execution time to make sure the input and output are indeed aliased.

- Copy insertion / buffer assignment and alias analysis has been updated to correctly honor the config and avoid any possible liveness interference.

PiperOrigin-RevId: 216737975
This commit is contained in:
Yunxing Dai 2018-10-11 12:07:37 -07:00 committed by TensorFlower Gardener
parent c304bd9bc9
commit 028410c7f4
16 changed files with 1078 additions and 27 deletions

View File

@ -294,6 +294,7 @@ cc_library(
srcs = [
"dfs_hlo_visitor.cc",
"hlo_computation.cc",
"hlo_input_output_alias_config.cc",
"hlo_instruction.cc",
"hlo_instructions.cc",
"hlo_module.cc",
@ -308,6 +309,7 @@ cc_library(
"hlo_clone_context.h",
"hlo_computation.h",
"hlo_domain_metadata.h",
"hlo_input_output_alias_config.h",
"hlo_instruction.h",
"hlo_instructions.h",
"hlo_module.h",
@ -1268,6 +1270,25 @@ tf_cc_test(
],
)
tf_cc_test(
name = "hlo_input_output_alias_config_test",
srcs = ["hlo_input_output_alias_config_test.cc"],
deps = [
":hlo",
":hlo_dce",
":hlo_memory_scheduler",
":hlo_ordering",
":hlo_parser",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/algorithm:container",
],
)
cc_library(
name = "hlo_memory_scheduler",
srcs = ["hlo_memory_scheduler.cc"],

View File

@ -239,7 +239,7 @@ BufferAllocation::Slice BufferAllocation::GetSlice(
void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset,
int64 size) {
VLOG(4) << "Trying to add " << buffer << " to " << this;
VLOG(4) << "Trying to add " << buffer << " to allocation #" << index();
CHECK(assigned_buffers_.count(&buffer) == 0)
<< "LogicalBuffer " << buffer << " already assigned to allocation "
<< index_;
@ -784,21 +784,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
}
}
if (allow_input_output_aliasing_ && allocation->maybe_live_out()) {
const HloComputation* entry_computation =
assignment->module_->entry_computation();
for (auto param : entry_computation->parameter_instructions()) {
for (auto& param_buffer :
assignment->points_to_analysis().GetBuffersDefinedByInstruction(
param)) {
if (assignment->liveness().MayInterfere(*param_buffer, buffer)) {
VLOG(4) << "Can't assign: Parameter interference with result";
return false;
}
}
}
}
// If the buffer is live out of the computation then it should only be
// assigned a buffer which exactly fits the result to avoid wasting memory
// (result buffers can have arbitrary lifetimes).
@ -1434,13 +1419,28 @@ BufferAssigner::MergeColocatedBufferSets(
// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
// in the same allocation (currently just supports kWhile, kCall, and
// kConditional).
// kConditional and input output aliasing).
void BufferAssigner::BuildColocatedBufferSets(
const HloModule* module, const BufferLiveness& buffer_liveness,
const LogicalBuffer::SizeFunction& buffer_size,
std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
const TuplePointsToAnalysis& points_to_analysis =
buffer_liveness.points_to_analysis();
// Set up colocated buffer set for input and output.
module->input_output_alias_config().ForEachAlias(
[&](const ShapeIndex& output_index, int64 param_number,
const ShapeIndex& param_index) {
std::vector<const LogicalBuffer*> colocated_set;
AddBufferToColocatedSet(module->entry_computation()->root_instruction(),
output_index, points_to_analysis,
&colocated_set);
AddBufferToColocatedSet(
module->entry_computation()->parameter_instruction(param_number),
param_index, points_to_analysis, &colocated_set);
AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
});
for (const HloComputation* computation : module->MakeComputationPostOrder()) {
if (computation->IsFusionComputation()) {
continue;

View File

@ -141,6 +141,9 @@ class BufferValue {
// operator< is required for std::set.
bool operator<(const BufferValue& other) const { return id_ < other.id_; }
bool operator==(const BufferValue& other) const { return id_ == other.id_; }
bool operator!=(const BufferValue& other) const { return id_ != other.id_; }
virtual string ToString() const = 0;
// TODO(lauj) rename LogicalBufferProto to BufferValueProto.

View File

@ -40,10 +40,12 @@ namespace {
using absl::StrAppend;
bool IsEntryParameterValue(const HloValue& value) {
bool IsReadonlyEntryParameterValue(const HloValue& value) {
const HloComputation* computation = value.defining_instruction()->parent();
return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
computation == computation->parent()->entry_computation();
computation == computation->parent()->entry_computation() &&
!computation->parent()->input_output_alias_config().ParameterHasAlias(
value.defining_instruction()->parameter_number(), value.index());
}
bool IsConstantValue(const HloValue& value) {
@ -51,7 +53,7 @@ bool IsConstantValue(const HloValue& value) {
}
bool ValueIsReadOnly(const HloValue& value) {
return IsConstantValue(value) || IsEntryParameterValue(value);
return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
}
// Data structure describing the action which should be taken on parts of a
@ -79,8 +81,7 @@ SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
bool ShouldCopyRootValue(const HloValue& value,
const SpecialCaseCopyPolicy& policy) {
if (policy.copy_parameters_and_constants) {
return IsConstantValue(value) ||
value.defining_instruction()->opcode() == HloOpcode::kParameter;
return ValueIsReadOnly(value);
}
return false;
}
@ -332,6 +333,81 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
return Status::OK();
}
// Conservatively adds copies before root instruction of entry computation and
// each aliased parameter to resolve interference of aliased input and output
// buffer. We later rely on the CopyRemover to drop the unnecessary ones.
Status AddCopiesForAliasedInputOutputs(HloModule* module) {
HloComputation* entry = module->entry_computation();
HloInstruction* root = entry->root_instruction();
ShapeTree<bool> output_indices_to_copy(root->shape());
// Copy trees indexed by parameter number; entries stay unset for parameters
// that are not aliased with any output. (Indexing a plain vector that was
// only push_back'd for aliased parameters would read the wrong — or a
// nonexistent — tree whenever an unaliased parameter precedes an aliased
// one.)
std::vector<absl::optional<ShapeTree<HloInstruction*>>> copied_parameters(
entry->num_parameters());
bool has_alias = false;
for (auto* param : entry->parameter_instructions()) {
bool param_has_alias = false;
ShapeTree<bool> param_indices_to_copy(param->shape());
module->input_output_alias_config().ForEachAlias(
[&](const ShapeIndex& output_index, int64 param_number,
const ShapeIndex& param_index) {
if (param_number == param->parameter_number()) {
param_has_alias = true;
*(param_indices_to_copy.mutable_element(param_index)) = true;
*(output_indices_to_copy.mutable_element(output_index)) = true;
}
});
if (!param_has_alias) {
continue;
}
has_alias = true;
// Store a snapshot of users before DeepCopyInstruction, as
// DeepCopyInstruction introduces new users of the instruction.
std::vector<HloInstruction*> users = param->users();
ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
/*init_value=*/nullptr);
TF_ASSIGN_OR_RETURN(HloInstruction * copied,
entry->DeepCopyInstruction(
param, &param_indices_to_copy, &param_copy_tree));
for (HloInstruction* user : users) {
TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
}
copied_parameters[param->parameter_number()] = param_copy_tree;
}
if (!has_alias) {
return Status::OK();
}
// Add copies before root instruction.
ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
/*init_value=*/nullptr);
TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
root->parent()->DeepCopyInstruction(
root, &output_indices_to_copy, &output_copy_tree));
// Add control dependencies between the input/output copies.
TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
[&](const ShapeIndex& output_index, int64 param_number,
const ShapeIndex& input_index) -> Status {
HloInstruction* from =
copied_parameters[param_number]->element(input_index);
HloInstruction* to = output_copy_tree.element(output_index);
TF_RET_CHECK(from != nullptr);
TF_RET_CHECK(to != nullptr);
TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
return Status::OK();
}));
entry->set_root_instruction(root_copied);
return Status::OK();
}
// Removes any control dependencies to or from the given instruction.
Status StripControlDependenciesFrom(HloInstruction* instruction) {
while (!instruction->control_successors().empty()) {
@ -953,6 +1029,8 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
}
}
}
TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
return Status::OK();
}

View File

@ -1351,6 +1351,218 @@ TEST_F(CopyInsertionTest, SwizzlingWhile) {
EXPECT_THAT(xla_while->operand(0), op::Tuple(op::Copy(), op::Copy()));
}
TEST_F(CopyInsertionTest, CrossingParameters) {
// Test a case where two parameters' dataflow cross with each other while
// input and output are aliased with same index:
//
// (p0 , p1)
// | \ /|
// | \ / |
// alias X alias
// | / \ |
// | / \|
// (p1 , p0)
auto module = CreateNewModule();
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0}));
module->AddEntryComputation(builder.Build());
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
InsertCopies(module.get());
// The outputs swap the inputs, so every aliased leaf interferes: both sides
// of each alias must be copied (2 parameter copies + 2 root copies).
EXPECT_EQ(CountCopies(*module), 4);
}
TEST_F(CopyInsertionTest, ParametersAliasing) {
// Test a case where two parameters' dataflow don't interfere with each other
// while aliased.
//
// (p0 , p1)
// | |
// | |
// alias alias
// | |
// | |
// (p0 , p1)
auto module = CreateNewModule();
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
module->AddEntryComputation(builder.Build());
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
InsertCopies(module.get());
// Each output passes its aliased input straight through, so the aliasing
// removes the copies that would otherwise protect the read-only parameter.
EXPECT_EQ(CountCopies(*module), 0);
}
TEST_F(CopyInsertionTest, ParameterWithNoAliasing) {
// Test a case where no parameter is aliased with result. In this case, copy
// should be added
//
// (p0 , p1)
// | |
// | |
// | |
// | |
// | |
// (p0 , p1)
auto module = CreateNewModule();
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
module->AddEntryComputation(builder.Build());
InsertCopies(module.get());
// Without aliasing, both parameter leaves flow to the root and must be
// copied so the read-only parameter buffer is never returned directly.
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::Copy(op::GetTupleElement(param, 0)),
op::Copy(op::GetTupleElement(param, 1))));
EXPECT_EQ(CountCopies(*module), 2);
}
TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
// Test a case where one parameter is aliased with result while another one
// isn't.
//
// (p0 , p1)
// | |
// | |
// alias |
// | |
// | |
// (p0 , p1)
auto module = CreateNewModule();
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
module->AddEntryComputation(builder.Build());
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
InsertCopies(module.get());
// Only the unaliased leaf ({1}) needs a protective copy; the aliased leaf
// ({0}) may be returned directly.
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Tuple(op::GetTupleElement(param, 0),
op::Copy(op::GetTupleElement(param, 1))));
EXPECT_EQ(CountCopies(*module), 1);
}
TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
// Test a case where one parameter is aliased with result while another one
// isn't.
//
// +-- (p0 , p1)
// | | |
// | | |
// alias Negate Negate
// | | |
// | | |
// +-- (p0 , p1)
auto module = CreateNewModule();
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
auto negate0 = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
auto negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
module->AddEntryComputation(builder.Build());
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
InsertCopies(module.get());
// The negates produce fresh values, so no copy is needed anywhere: the
// aliased output does not interfere with the aliased input's live range.
EXPECT_EQ(CountCopies(*module), 0);
}
TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
// Test a case where one parameter is aliased with result while another one
// isn't.
//
// +-- (p0 , p1)
// | | |
// | | |
// alias Negate Negate
// | | |
// | Add----+
// | | |
// +-- (p0 , p1)
auto module = CreateNewModule();
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
auto negate0 = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
auto negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
auto add = builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, negate0, negate1));
builder.AddInstruction(HloInstruction::CreateTuple({add, negate1}));
module->AddEntryComputation(builder.Build());
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
InsertCopies(module.get());
// Both output leaves are freshly computed values (add, negate1), so the
// alias introduces no interference and no copies are required.
EXPECT_EQ(CountCopies(*module), 0);
}
TEST_F(CopyInsertionTest, SwizzlingWhileWithOneOp) {
// Test a while instruction with a body which permutes its tuple parameter
// elements and applies one operation to one of the elements. The addition of

View File

@ -225,6 +225,32 @@ message HloScheduleProto {
map<int64, InstructionSequence> sequences = 1;
}
message HloInputOutputAliasProto {
// Each entry describes one aliased (input, output) pair: the input is
// identified by a parameter number plus a ShapeIndex into that parameter,
// and the output by a ShapeIndex into the root instruction. For example:
//
// entry = {
// output_shape_index={1},
// parameter_number=0,
// parameter_shape_index={1, 2},
// }
//
// This entry indicates that the first parameter's {1, 2} element is
// aliased with the {1} element of the root instruction.
message AliasEntryProto {
// ShapeIndex of the root hlo.
repeated int64 output_shape_index = 1;
// Number of the parameter in entry computation.
int64 parameter_number = 2;
// ShapeIndex of the parameter instruction.
repeated int64 parameter_shape_index = 3;
}
repeated AliasEntryProto entries = 1;
}
// Serialization of HloModule.
message HloModuleProto {
string name = 1;
@ -243,6 +269,9 @@ message HloModuleProto {
// The schedule for this module.
HloScheduleProto schedule = 7;
// Describes alias information between inputs and outputs.
HloInputOutputAliasProto input_output_alias = 8;
}
// Serialization of LogicalBuffer.

View File

@ -59,8 +59,9 @@ class BufferValueMap {
// construction process.
using BufferNumber = int64;
explicit BufferValueMap(const HloDataflowAnalysis& dataflow)
: dataflow_(dataflow) {
explicit BufferValueMap(HloModule* module,
const HloDataflowAnalysis& dataflow)
: module_(module), dataflow_(dataflow) {
buffers_.reserve(dataflow_.values().size());
value_to_buffer_number_.reserve(dataflow_.values().size());
for (const HloValue* value : dataflow_.values()) {
@ -171,6 +172,42 @@ class BufferValueMap {
return value_to_buffer_number_.at(&value);
}
// Appends to `aliased_buffers` the buffer numbers that `value` must share
// due to the module's input/output alias config: the aliased parameter's
// buffer when `value` appears at an aliased position of the entry root, and
// the value's own buffer when it is itself an entry parameter.
void ComputeInputOutputAliasedBuffers(
const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
// Get parameter value from an aliased_input object.
const auto get_parameter_value =
[this](const std::pair<int64, ShapeIndex>& aliased_input)
-> const HloValue& {
int64 param_number = aliased_input.first;
const ShapeIndex& param_index = aliased_input.second;
return dataflow_.GetUniqueValueAt(
module_->entry_computation()->parameter_instruction(param_number),
param_index);
};
// If the value shows up in a root instruction, alias it with parameter
// instruction.
for (const HloPosition& pos : value.positions()) {
if (pos.instruction == module_->entry_computation()->root_instruction()) {
ShapeIndex output_index = pos.index;
auto aliased_input =
module_->input_output_alias_config().GetAliasedParameter(
output_index);
if (aliased_input) {
aliased_buffers->push_back(
GetBufferForValue(get_parameter_value(*aliased_input)));
}
}
}
// If the value is parameter instruction itself, alias it with itself.
if (value.instruction()->opcode() == HloOpcode::kParameter &&
value.instruction()->parent() == module_->entry_computation()) {
aliased_buffers->push_back(GetBufferForValue(value));
}
}
void ComputeWhileAliasedBuffers(const HloValue& value,
std::vector<BufferNumber>* aliased_buffers) {
VLOG(3) << "Compute kWhile aliases";
@ -278,6 +315,7 @@ class BufferValueMap {
VLOG(2) << "Use of value " << value.ToShortString() << ": " << use;
}
std::vector<BufferNumber> aliased_buffers;
ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
ComputeWhileAliasedBuffers(value, &aliased_buffers);
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
// Uniquify aliased buffers.
@ -288,6 +326,8 @@ class BufferValueMap {
return aliased_buffers;
}
HloModule* module_;
// Dataflow analysis used to construct the buffer map.
const HloDataflowAnalysis& dataflow_;
@ -461,7 +501,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
/*bitcast_defines_value=*/false,
fusion_can_share_buffer));
BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
buffer_map.MergeAliasedBuffers();
// Create a vector of HloBuffers, one for each set of values in the

View File

@ -217,6 +217,181 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) {
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) {
// Parameter leaves aliased one-to-one with the output leaves: alias
// analysis must place each aliased parameter leaf and the corresponding
// output leaf into the same buffer.
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
auto negate0 = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte0));
auto negate1 = builder.AddInstruction(
HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, gte1));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
module_->AddEntryComputation(builder.Build());
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
// Cannot alias an output twice.
ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
const HloAliasAnalysis& analysis = RunAnalysis();
EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
}
TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) {
// parameter 0 aliased with output 1 and parameter 1 aliased with output 0.
//
// (p0 , p1)
// \ /
// \ /
// alias X
// / \
// / \
// (p0 , p1)
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
auto gte0 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 0));
auto gte1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
module_->AddEntryComputation(builder.Build());
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}));
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
// Cannot alias an output twice.
ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
const HloAliasAnalysis& analysis = RunAnalysis();
// The cross-alias merges everything: every op in this graph ends up in the
// same buffer as every other.
EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
EXPECT_EQ(analysis.GetUniqueBufferAt(gte0),
analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
analysis.GetUniqueBufferAt(tuple, /*index=*/{0}));
EXPECT_EQ(analysis.GetUniqueBufferAt(gte1),
analysis.GetUniqueBufferAt(tuple, /*index=*/{1}));
}
TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) {
// Test a simple single while instruction can be aliased with input and output
// of the computation.
//
// body((F32[], F32[]) %tuple_param):
// %add = Add(%tuple_param{0}, %tuple_param{1})
// return Tuple(%tuple_param{0}, %add)
//
// condition((F32[], F32[]) %tuple_param):
// return Constant(false)
//
// entry:
// %param1 = param1
// %while = While(%param1, body, condition)
// %while_1 = GTE(%while, 0)
// %while_2 = GTE(%while, 1)
// %negate_1 = Negate(%while_1)
// %negate_2 = Negate(%while_2)
// return Tuple(negate_1, negate_2)
//
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_});
// Element 0 passes transparently through the body.
auto body_builder = HloComputation::Builder("body");
auto body_param = body_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
auto body_element_0 = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 0));
auto body_element_1 = body_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1));
auto add = body_builder.AddInstruction(HloInstruction::CreateBinary(
scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1));
auto body_tuple = body_builder.AddInstruction(
HloInstruction::CreateTuple({body_element_0, add}));
HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
// Condition computation trivially returns a constant "false".
auto cond_builder = HloComputation::Builder("condition");
auto cond_param = cond_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "param"));
cond_builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
HloComputation* condition =
module_->AddEmbeddedComputation(cond_builder.Build());
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
auto xla_while = builder.AddInstruction(
HloInstruction::CreateWhile(tuple_shape, condition, body, param));
auto while_element_1 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 0));
auto while_element_2 = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(scalar_shape_, xla_while, 1));
auto negate_1 = builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kNegate, while_element_1));
auto negate_2 = builder.AddInstruction(HloInstruction::CreateUnary(
scalar_shape_, HloOpcode::kNegate, while_element_2));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({negate_1, negate_2}));
module_->AddEntryComputation(builder.Build());
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
const HloAliasAnalysis& analysis = RunAnalysis();
// The alias merges the entry parameter/output buffers into the while's
// buffer set, so the {1} buffer contains values from the entry parameter,
// the while loop internals, and the aliased output producer.
EXPECT_THAT(
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
UnorderedElementsAre(GetValueDefinedAt(param, {1}),
GetValueDefinedAt(xla_while, /*index=*/{1}),
GetValueDefinedAt(body_param, {1}),
GetValueDefinedAt(cond_param, {1}),
GetValueDefinedAt(add),
GetValueDefinedAt(negate_2)));
EXPECT_THAT(
analysis.GetUniqueBufferAt(xla_while, /*index=*/{1}).ComputePositions(),
UnorderedElementsAre(
HloPosition{param, {1}}, HloPosition{xla_while, {1}},
HloPosition{while_element_2, {}}, HloPosition{body_param, {1}},
HloPosition{body_element_1, {}}, HloPosition{add, {}},
HloPosition{body_tuple, {1}}, HloPosition{tuple, {1}},
HloPosition{cond_param, {1}}, HloPosition{negate_2, {}}));
EXPECT_FALSE(AnyValuesInSameBufferInterfere());
}
TEST_F(HloAliasAnalysisTest, SingleCall) {
// Test a single call of a subcomputation. The subcomputation adds its two
// array-shaped parameters.

View File

@ -126,7 +126,7 @@ bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
const HloInstruction* instruction, const ShapeIndex& index) const {
CHECK(ValueIsDefinedAt(instruction, index));
CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
return GetUniqueValueAt(instruction, index);
}

View File

@ -0,0 +1,182 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
namespace xla {
// Registers an aliasing pair: the entry computation's output at
// `output_index` shares its buffer with parameter `param_number` at
// `param_index`. Returns an error if `output_index` is invalid for the
// output shape, or if that output element is already aliased.
Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
int64 param_number,
const ShapeIndex& param_index) {
TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
// Fixed typo in the error message: "Tring" -> "Trying".
<< absl::StrCat("Trying to set up alias at ", output_index.ToString(),
" which is an invalid index for shape ",
ShapeUtil::HumanString(alias_.shape()));
// Output can't be aliased with multiple parameters.
TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat(
"Trying to set up output alias for param %lld at %s but failed: output "
"index %s is already aliased with param %lld at %s",
param_number, param_index.ToString(), output_index.ToString(),
alias_.element(output_index)->first,
alias_.element(output_index)->second.ToString());
(*alias_.mutable_element(output_index)) =
std::make_pair(param_number, param_index);
return Status::OK();
}
// Serializes this config into an HloInputOutputAliasProto, emitting one
// AliasEntryProto for each output element that has an alias.
HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
HloInputOutputAliasProto result;
alias_.ForEachElement(
[&result](const ShapeIndex& index,
const absl::optional<std::pair<int64, ShapeIndex>>& data) {
if (!data) {
return;
}
HloInputOutputAliasProto::AliasEntryProto entry;
for (int64 dim : index) {
entry.add_output_shape_index(dim);
}
entry.set_parameter_number(data->first);
for (int64 dim : data->second) {
entry.add_parameter_shape_index(dim);
}
result.add_entries()->Swap(&entry);
});
return result;
}
// Deserializes `proto` into a config for a module whose output shape is
// `output_shape`. Returns an error on invalid indices or duplicate output
// aliases (via SetUpAlias).
StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
    const Shape& output_shape, const HloInputOutputAliasProto& proto) {
HloInputOutputAliasConfig result(output_shape);
for (const HloInputOutputAliasProto::AliasEntryProto& entry :
proto.entries()) {
const ShapeIndex output_index(entry.output_shape_index().begin(),
entry.output_shape_index().end());
const ShapeIndex param_index(entry.parameter_shape_index().begin(),
entry.parameter_shape_index().end());
TF_RETURN_IF_ERROR(
result.SetUpAlias(output_index, entry.parameter_number(), param_index));
}
return result;
}
// Renders the alias config as a human-readable multi-line string, one line
// per alias entry.
string HloInputOutputAliasConfig::ToString() const {
std::vector<string> pieces;
pieces.push_back("HloInputOutputAliasConfig");
ForEachAlias([&pieces](const ShapeIndex& output_index, int64 param_number,
const ShapeIndex& param_index) {
pieces.push_back(absl::StrFormat(
" OutputIndex %s is aliased with parameter %lld at %s:",
output_index.ToString(), param_number, param_index.ToString()));
});
return absl::StrJoin(pieces, "\n");
}
// Returns true if parameter `param_number` at `param_index` is aliased with
// some output element of the entry computation.
bool HloInputOutputAliasConfig::ParameterHasAlias(
    int64 param_number, const ShapeIndex& param_index) const {
bool output = false;
alias_.ForEachElement(
// Take the element by const reference: the previous by-value signature
// copied an absl::optional<std::pair<int64, ShapeIndex>> (including the
// ShapeIndex's storage) for every element of the output shape.
[&](const xla::ShapeIndex&,
const absl::optional<std::pair<int64, ShapeIndex>>& alias) {
if (alias && alias->first == param_number &&
alias->second == param_index) {
output = true;
}
});
return output;
}
// Returns the output ShapeIndex aliased with parameter `param_number` at
// `param_index`, or nullopt if that parameter element is not aliased. If
// multiple outputs reference the same parameter element, the last visited
// one is returned.
absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
    int64 param_number, const ShapeIndex& param_index) const {
absl::optional<ShapeIndex> output;
alias_.ForEachElement(
// Take the element by const reference to avoid copying the optional
// (and its ShapeIndex) for every element of the output shape.
[&](const xla::ShapeIndex& output_index,
const absl::optional<std::pair<int64, ShapeIndex>>& alias) {
if (alias && alias->first == param_number &&
alias->second == param_index) {
output = output_index;
}
});
return output;
}
// Returns the (parameter number, parameter index) pair aliased with the
// output element at `output_index`, or nullopt if it is not aliased.
// CHECK-fails if `output_index` is not a valid index for the output shape.
absl::optional<std::pair<int64, ShapeIndex>>
HloInputOutputAliasConfig::GetAliasedParameter(
const ShapeIndex& output_index) const {
CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
return alias_.element(output_index);
}
void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const {
  // Walk the whole output shape tree, invoking `fn` only at positions where
  // an alias entry was recorded.
  alias_.ForEachElement(
      [&fn](const ShapeIndex& output_index,
            const absl::optional<std::pair<int64, ShapeIndex>>& alias) {
        if (!alias.has_value()) {
          return;
        }
        fn(output_index, alias->first, alias->second);
      });
}
Status HloInputOutputAliasConfig::ForEachAliasWithStatus(
    AliasFnWithStatus fn) const {
  // Same traversal as ForEachAlias, but propagates the first non-OK status
  // returned by `fn` and stops there.
  return alias_.ForEachElementWithStatus(
      [&fn](const ShapeIndex& output_index,
            const absl::optional<std::pair<int64, ShapeIndex>>& alias) {
        if (!alias.has_value()) {
          return Status::OK();
        }
        return fn(output_index, alias->first, alias->second);
      });
}
Status HloInputOutputAliasConfig::Verify(const HloModule& module) const {
  // Verifies that every alias entry is in bounds for the entry computation
  // and that no parameter buffer is claimed by more than one output buffer.
  //
  // param_has_seen[i] marks, per leaf of parameter i's shape, whether some
  // output already aliases that leaf.
  std::vector<ShapeTree<bool>> param_has_seen;
  const HloComputation* entry = module.entry_computation();
  for (int64 i = 0; i < entry->num_parameters(); ++i) {
    HloInstruction* param = entry->parameter_instruction(i);
    param_has_seen.emplace_back(param->shape());
  }
  return ForEachAliasWithStatus([&](const ShapeIndex& output_index,
                                    int64 param_number,
                                    const ShapeIndex& param_index) -> Status {
    // Validate param_number BEFORE calling parameter_instruction(): with an
    // out-of-range number that call would CHECK-fail instead of returning a
    // useful error status.
    TF_RET_CHECK(param_number >= 0);
    TF_RET_CHECK(entry->num_parameters() > param_number);
    const HloInstruction* root = entry->root_instruction();
    const Shape& param_shape =
        entry->parameter_instruction(param_number)->shape();
    const Shape& output_shape = root->shape();
    TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, param_index));
    TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index));
    // Each (param_number, param_index) pair may be aliased by at most one
    // output buffer.
    TF_RET_CHECK(param_has_seen[param_number].element(param_index) == false);
    *(param_has_seen[param_number].mutable_element(param_index)) = true;
    return Status::OK();
  });
}
std::ostream& operator<<(std::ostream& out,
                         const HloInputOutputAliasConfig& config) {
  // Delegate to the config's own string rendering.
  return out << config.ToString();
}
} // namespace xla

View File

@ -0,0 +1,102 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_
#include <utility>
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
class HloModule;
// Specifies the alias map from output (shape) index to parameter number and
// parameter (shape) index in the entry computation. Backends that honor this
// config must make the listed input and output buffers share storage at
// execution time.
class HloInputOutputAliasConfig {
 public:
  HloInputOutputAliasConfig() = default;

  // `shape` is the shape of the entry computation's root; it determines the
  // set of valid output indices.
  explicit HloInputOutputAliasConfig(Shape shape) : alias_(shape) {}

  virtual ~HloInputOutputAliasConfig() = default;

  // Sets up an alias from `output_index` to `param_index` of parameter
  // `param_number`. Returns an error if `output_index` already has an alias.
  Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
                    const ShapeIndex& param_index);

  // Returns true if the given parameter buffer is aliased with one of the
  // output buffers.
  bool ParameterHasAlias(int64 param_number,
                         const ShapeIndex& param_index) const;

  // (De)Serializes an HloInputOutputAliasConfig to/from an
  // HloInputOutputAliasProto.
  HloInputOutputAliasProto ToProto() const;

  static StatusOr<HloInputOutputAliasConfig> CreateFromProto(
      const Shape& output_shape, const HloInputOutputAliasProto& proto);

  // Returns the output index that the given parameter and parameter index is
  // aliased with. A nullopt is returned if there is no output that is aliased
  // with the parameter number and index.
  absl::optional<ShapeIndex> GetAliasedOutput(
      int64 param_number, const ShapeIndex& param_index) const;

  // Returns the parameter number and parameter index of the buffer that the
  // given output buffer index is aliased with. A nullopt is returned if no
  // parameter is aliased with the specific output.
  absl::optional<std::pair<int64, ShapeIndex>> GetAliasedParameter(
      const ShapeIndex& output_index) const;

  using AliasFn =
      std::function<void(const ShapeIndex& output_index, int64 param_number,
                         const ShapeIndex& param_index)>;

  // Iterates through each aliased output and input.
  void ForEachAlias(AliasFn fn) const;

  using AliasFnWithStatus =
      std::function<Status(const ShapeIndex& output_index, int64 param_number,
                           const ShapeIndex& param_index)>;

  // Verifies that the config is valid for the given module: every aliased
  // output/parameter index must be in bounds for the entry computation, and
  // each parameter buffer may be aliased by at most one output buffer.
  Status Verify(const HloModule& module) const;

  // Like ForEachAlias, but stops and propagates the first non-OK status
  // returned by `fn`.
  Status ForEachAliasWithStatus(AliasFnWithStatus fn) const;

  string ToString() const;

 private:
  // A ShapeTree which indicates the list of buffers that's expected to be
  // aliased. The key on this shape tree represents the output index. The value
  // is a pair of parameter number and index into the parameter's shape. If the
  // value is nullopt, there is no parameter aliasing for this output.
  ShapeTree<absl::optional<std::pair<int64, ShapeIndex>>> alias_;
};
std::ostream& operator<<(std::ostream& out,
const HloInputOutputAliasConfig& config);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INPUT_OUTPUT_ALIAS_CONFIG_H_

View File

@ -0,0 +1,184 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include <memory>
#include <string>
#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
class HloInputOutputAliasConfigTest : public HloTestBase {
 protected:
  // Asserts that `output_index` and parameter `param_number` at `param_index`
  // are aliased in `config`, checking both lookup directions
  // (GetAliasedOutput and GetAliasedParameter).
  void expect_aliased(const ShapeIndex& output_index, int64 param_number,
                      const ShapeIndex& param_index,
                      const HloInputOutputAliasConfig& config) {
    absl::optional<ShapeIndex> aliased_output =
        config.GetAliasedOutput(param_number, param_index);
    EXPECT_TRUE(aliased_output);
    EXPECT_EQ(aliased_output.value(), output_index);

    absl::optional<std::pair<int64, ShapeIndex>> aliased_param =
        config.GetAliasedParameter(output_index);
    EXPECT_TRUE(aliased_param);
    EXPECT_EQ(aliased_param.value(), std::make_pair(param_number, param_index));
  }

  // Asserts that `output_index` and parameter `param_number` at `param_index`
  // are NOT aliased with each other in `config`, in either direction. Other
  // aliases involving the same output or parameter are still permitted.
  void expect_not_aliased(const ShapeIndex& output_index, int64 param_number,
                          const ShapeIndex& param_index,
                          const HloInputOutputAliasConfig& config) {
    absl::optional<ShapeIndex> aliased_output =
        config.GetAliasedOutput(param_number, param_index);
    EXPECT_FALSE(aliased_output && aliased_output == output_index);

    absl::optional<std::pair<int64, ShapeIndex>> aliased_param =
        config.GetAliasedParameter(output_index);
    EXPECT_FALSE(aliased_param && aliased_param->first == param_number &&
                 aliased_param->second == param_index);
  }
};
// Aliasing a non-tuple parameter with one element of the tuple output:
// only the exact (output_index, param_number, param_index) triple that was
// set up should be reported as aliased.
TEST_F(HloInputOutputAliasConfigTest, SimpleAliasing) {
  const string module_str = R"(
HloModule TEST
ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));
  HloInputOutputAliasConfig config(
      module->entry_computation()->root_instruction()->shape());

  // Output element {0} aliases parameter 1 (a scalar, so param_index is {}).
  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
                                 /*param_index=*/{}));

  expect_aliased(/*output_index=*/{0}, /*param_number=*/1,
                 /*param_index=*/{}, config);

  // Neither the other output element nor the other parameter participate.
  expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
                     /*param_index=*/{}, config);

  expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
                     /*param_index=*/{}, config);
}
// Aliasing individual leaves of a tuple-shaped parameter with the matching
// leaves of the tuple output; non-leaf / unrelated indices stay unaliased.
TEST_F(HloInputOutputAliasConfigTest, SimpleAliasingWithTupleInput) {
  const string module_str = R"(
HloModule TEST
ENTRY main {
  param = (f32[], f32[]) parameter(0)
  gte1 = f32[] get-tuple-element(%param), index=0
  gte2 = f32[] get-tuple-element(%param), index=1
  ROOT root = (f32[], f32[]) tuple(%gte1, %gte2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));
  HloInputOutputAliasConfig config(
      module->entry_computation()->root_instruction()->shape());

  // Output leaf {0} <-> parameter 0 leaf {0}; output leaf {1} <-> leaf {1}.
  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                                 /*param_index=*/{0}));

  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
                                 /*param_index=*/{1}));

  expect_aliased(/*output_index=*/{0}, /*param_number=*/0,
                 /*param_index=*/{0}, config);

  expect_aliased(/*output_index=*/{1}, /*param_number=*/0,
                 /*param_index=*/{1}, config);

  // A non-existent parameter and the tuple's root index are not aliased.
  expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
                     /*param_index=*/{}, config);

  expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
                     /*param_index=*/{}, config);
}
// Verify() must reject a config in which the same parameter buffer is
// claimed by two different output buffers.
TEST_F(HloInputOutputAliasConfigTest, InputDoNotAliasTwice) {
  const string module_str = R"(
HloModule TEST
ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));
  HloInputOutputAliasConfig config(
      module->entry_computation()->root_instruction()->shape());

  // Both output leaves point at parameter 0 — SetUpAlias allows this (the
  // outputs differ), but Verify must flag the duplicated parameter.
  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                                 /*param_index=*/{}));

  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{1}, /*param_number=*/0,
                                 /*param_index=*/{}));

  ASSERT_IS_NOT_OK(config.Verify(*module));
}
// SetUpAlias() itself must reject mapping the same output buffer to a second
// parameter — this is caught eagerly, unlike the duplicate-parameter case.
TEST_F(HloInputOutputAliasConfigTest, OutputDoNotAliasTwice) {
  const string module_str = R"(
HloModule TEST
ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));
  HloInputOutputAliasConfig config(
      module->entry_computation()->root_instruction()->shape());

  TF_ASSERT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                                 /*param_index=*/{}));

  // Second alias for output {0} must fail.
  ASSERT_IS_NOT_OK(config.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
                                     /*param_index=*/{}));
}
} // namespace
} // namespace xla

View File

@ -73,6 +73,8 @@ HloComputation* HloModule::AddComputationInternal(
config_.SetDefaultComputationLayout(
entry_computation_->ComputeProgramShape());
}
input_output_alias_config_ = HloInputOutputAliasConfig(
entry_computation_->root_instruction()->shape());
}
if (uniquify_identifiers) {
@ -252,6 +254,9 @@ HloModuleProto HloModule::ToProto() const {
if (has_schedule()) {
*proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
}
*proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
return proto;
}
@ -328,6 +333,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
}
TF_RET_CHECK(module->entry_computation_ != nullptr);
TF_ASSIGN_OR_RETURN(module->input_output_alias_config_,
HloInputOutputAliasConfig::CreateFromProto(
result_shape, proto.input_output_alias()));
// Because we didn't uniquify the names or the ids, double-check that the
// instruction and computation names and ids are unique from the proto.
absl::flat_hash_set<string> computation_names;

View File

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
@ -222,6 +223,15 @@ class HloModule {
return result;
}
// input_output_alias_config indicates the list of aliased buffers that are
// expected from the module.
HloInputOutputAliasConfig& input_output_alias_config() {
return input_output_alias_config_;
}
const HloInputOutputAliasConfig& input_output_alias_config() const {
return input_output_alias_config_;
}
// Returns an id that is unique to this module across all modules created over
// the lifetime of this process.
int unique_id() const { return unique_id_; }
@ -290,6 +300,10 @@ class HloModule {
// sequential order of instructions for each non-fusion computation in the
// module.
absl::optional<HloSchedule> schedule_;
// alias_config indicates the alias information of input/output buffers that
// are expected from the module.
HloInputOutputAliasConfig input_output_alias_config_;
};
} // namespace xla

View File

@ -1316,6 +1316,8 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(module->schedule().Verify());
}
TF_RETURN_IF_ERROR(module->input_output_alias_config().Verify(*module));
return false;
}

View File

@ -72,7 +72,7 @@ class ShapeIndex {
void push_back(int64 value) { indices_.push_back(value); }
void pop_back() { indices_.pop_back(); }
// push_front is O(n^2), but shapes don't usually have a ton of dimensions.
// push_front is O(n), but shapes don't usually have a ton of dimensions.
void push_front(int64 value) { indices_.insert(indices_.begin(), value); }
using container_type = absl::InlinedVector<int64, 2>;