[XLA] Unify aliasing types
SYSTEM/USER alias distinction is not actually used, and knowing it at compile time does not bring any advantages, as we check actual aliasing at runtime in any case. PiperOrigin-RevId: 320079893 Change-Id: I726cfe9dae0256904778a3bc3e501566aa026f9f
This commit is contained in:
parent
66668ec0ca
commit
0103bdb3cd
@ -244,9 +244,7 @@ static bool MustAliasOutput(
|
||||
if (input_output_alias.shape().tuple_shapes_size() == 0) {
|
||||
return false;
|
||||
}
|
||||
return input_output_alias.OutputHasAlias(output_index) &&
|
||||
input_output_alias.GetAliasedParameter(output_index).value().kind ==
|
||||
xla::HloInputOutputAliasConfig::kUserAlias;
|
||||
return input_output_alias.OutputHasAlias(output_index);
|
||||
}
|
||||
|
||||
// Returns an aliased tensor if it exists, nullptr otherwise.
|
||||
|
@ -424,9 +424,8 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id,
|
||||
alias.param_number,
|
||||
alias.param_index.ToString().c_str());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(config.SetUpAlias(
|
||||
alias.output_index, alias.param_number, alias.param_index,
|
||||
HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
|
||||
alias.param_index));
|
||||
}
|
||||
*module->mutable_input_output_alias() = config.ToProto();
|
||||
return Status::OK();
|
||||
|
@ -492,8 +492,7 @@ TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) {
|
||||
auto module = CreateNewVerifiedModule();
|
||||
module->AddEntryComputation(builder.Build());
|
||||
|
||||
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias(
|
||||
{}, 0, {}, HloInputOutputAliasConfig::kUserAlias));
|
||||
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({}, 0, {}));
|
||||
|
||||
auto buffers = RunBufferAssignment(module.get());
|
||||
|
||||
|
@ -1586,11 +1586,9 @@ TEST_F(CopyInsertionTest, CrossingParameters) {
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0}));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
|
||||
InsertCopies(module.get());
|
||||
|
||||
EXPECT_EQ(CountCopies(*module), 4);
|
||||
@ -1621,11 +1619,9 @@ TEST_F(CopyInsertionTest, ParametersAliasing) {
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
|
||||
InsertCopies(module.get());
|
||||
|
||||
EXPECT_EQ(CountCopies(*module), 0);
|
||||
@ -1689,8 +1685,7 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||
InsertCopies(module.get());
|
||||
|
||||
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
@ -1731,8 +1726,7 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||
InsertCopies(module.get());
|
||||
|
||||
EXPECT_EQ(CountCopies(*module), 0);
|
||||
@ -1773,8 +1767,7 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
|
||||
builder.AddInstruction(HloInstruction::CreateTuple({add, negate1}));
|
||||
module->AddEntryComputation(builder.Build());
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||
InsertCopies(module.get());
|
||||
|
||||
EXPECT_EQ(CountCopies(*module), 0);
|
||||
@ -2505,11 +2498,11 @@ ENTRY entry_computation {
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{1},
|
||||
/*param_number=*/0,
|
||||
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{}));
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{3},
|
||||
/*param_number=*/1,
|
||||
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{}));
|
||||
|
||||
InsertCopies(module.get());
|
||||
|
||||
@ -2532,7 +2525,7 @@ ENTRY Entry {
|
||||
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{1},
|
||||
/*param_number=*/0,
|
||||
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{}));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
@ -256,17 +256,15 @@ StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
|
||||
se::DeviceMemoryBase argument_buffer = owning->Release();
|
||||
*maybe_owning_memory = argument_buffer;
|
||||
result_buffer = argument_buffer;
|
||||
if (alias->kind == HloInputOutputAliasConfig::kUserAlias) {
|
||||
// This is a user alias, so a must alias. The caller is giving us the
|
||||
// input buffer, but in case of error of the execute call, we should
|
||||
// not be releasing it as it contains valid data (for example, it is a
|
||||
// parameter which the user wants us to alias, in a gradient update
|
||||
// computation). So we store the index into the result in the aliased
|
||||
// vactor, which will be fed to the ExecutionOutput, which will be
|
||||
// using the indices to drop the addresses from its own
|
||||
// ScopedShapedBuffer result, if the ExecutionOutput is not committed.
|
||||
result.AddAliasedIndex(index);
|
||||
}
|
||||
// The caller is giving us the
|
||||
// input buffer, but in case of error of the execute call, we should
|
||||
// not be releasing it as it contains valid data (for example, it is a
|
||||
// parameter which the user wants us to alias, in a gradient update
|
||||
// computation). So we store the index into the result in the aliased
|
||||
// vactor, which will be fed to the ExecutionOutput, which will be
|
||||
// using the indices to drop the addresses from its own
|
||||
// ScopedShapedBuffer result, if the ExecutionOutput is not committed.
|
||||
result.AddAliasedIndex(index);
|
||||
} else {
|
||||
VLOG(3) << "Using copy-protection: aliasing is specified, but the "
|
||||
"buffer is not donated; allocating a fresh buffer";
|
||||
|
@ -37,12 +37,10 @@ StatusOr<bool> AliasPassthroughParams::Run(HloModule* module) {
|
||||
<< " in module " << module->name()
|
||||
<< " is passed-through to root tuple element " << i << ": "
|
||||
<< root->shape().ToString();
|
||||
// Use must-alias semantics (kUserAlias) for pass-through params.
|
||||
TF_RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{i},
|
||||
/*param_number=*/root->operand(i)->parameter_number(),
|
||||
/*param_index=*/{},
|
||||
HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{}));
|
||||
used_params.insert(root->operand(i)->parameter_number());
|
||||
changed = true;
|
||||
}
|
||||
|
@ -38,12 +38,8 @@ TEST_F(AliasPassthroughParamsTest, AliasPassThroughParams) {
|
||||
EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).ValueOrDie());
|
||||
const auto& alias_config = module->input_output_alias_config();
|
||||
EXPECT_EQ(0, alias_config.GetAliasedParameter({0}).value().parameter_number);
|
||||
EXPECT_EQ(xla::HloInputOutputAliasConfig::kUserAlias,
|
||||
alias_config.GetAliasedParameter({0}).value().kind);
|
||||
EXPECT_FALSE(alias_config.OutputHasAlias({1}));
|
||||
EXPECT_EQ(1, alias_config.GetAliasedParameter({2}).value().parameter_number);
|
||||
EXPECT_EQ(xla::HloInputOutputAliasConfig::kUserAlias,
|
||||
alias_config.GetAliasedParameter({2}).value().kind);
|
||||
}
|
||||
|
||||
TEST_F(AliasPassthroughParamsTest, DoNotAliasPassThroughParamsMoreThanOnce) {
|
||||
|
@ -499,17 +499,15 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
|
||||
se::DeviceMemoryBase argument_buffer = owning->Release();
|
||||
*maybe_owning_memory = argument_buffer;
|
||||
result_buffer = argument_buffer;
|
||||
if (alias->kind == HloInputOutputAliasConfig::kUserAlias) {
|
||||
// This is a user alias, so a must alias. The caller is giving us the
|
||||
// input buffer, but in case of error from the execute call, we should
|
||||
// not be releasing it as it contains valid data (for example, it is a
|
||||
// parameter which the user wants us to alias, in a gradient update
|
||||
// computation). So we store the index into the result in the aliased
|
||||
// vector, which will be fed to the ExecutionOutput, which will use
|
||||
// the indices to drop the addresses from its own ScopedShapedBuffer
|
||||
// result, if the ExecutionOutput is not committed.
|
||||
result.AddAliasedIndex(index);
|
||||
}
|
||||
// The caller is giving us the
|
||||
// input buffer, but in case of error from the execute call, we should
|
||||
// not be releasing it as it contains valid data (for example, it is a
|
||||
// parameter which the user wants us to alias, in a gradient update
|
||||
// computation). So we store the index into the result in the aliased
|
||||
// vector, which will be fed to the ExecutionOutput, which will use
|
||||
// the indices to drop the addresses from its own ScopedShapedBuffer
|
||||
// result, if the ExecutionOutput is not committed.
|
||||
result.AddAliasedIndex(index);
|
||||
} else if (src_hlo->opcode() != HloOpcode::kParameter) {
|
||||
// The guard is above is not to insert copy-protection when aliasing
|
||||
// pass-through params, as we do not need to write into the output
|
||||
|
@ -284,18 +284,6 @@ message HloScheduleProto {
|
||||
}
|
||||
|
||||
message HloInputOutputAliasProto {
|
||||
enum Kind {
|
||||
// Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3
|
||||
// behavior and missing has_*() APIs.
|
||||
UNDEFINED_ALIAS = 0;
|
||||
// An alias setup by the user as must alias. A use setting USER_ALIAS is
|
||||
// expecting the designed output to be dropped over the given input
|
||||
// parameter number+index.
|
||||
USER_ALIAS = 1;
|
||||
// An alias setup by the compiler as part of its optimizations.
|
||||
SYSTEM_ALIAS = 2;
|
||||
}
|
||||
|
||||
// The following proto describes a pair of aliased an input
|
||||
// (described by parameter number and a ShapeIndex of the parameter)
|
||||
// and an output (described by a ShapeIndex of the root
|
||||
@ -316,8 +304,8 @@ message HloInputOutputAliasProto {
|
||||
int64 parameter_number = 2;
|
||||
// ShapeIndex of the parameter instruction.
|
||||
repeated int64 parameter_shape_index = 3;
|
||||
// The kind of alias to be setup.
|
||||
Kind kind = 4;
|
||||
reserved 4;
|
||||
reserved "kind";
|
||||
}
|
||||
|
||||
repeated AliasEntryProto entries = 1;
|
||||
|
@ -241,16 +241,13 @@ TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) {
|
||||
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
|
||||
|
||||
// Cannot alias an output twice.
|
||||
ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
|
||||
|
||||
const HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
@ -287,16 +284,13 @@ TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) {
|
||||
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}));
|
||||
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
|
||||
|
||||
// Cannot alias an output twice.
|
||||
ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
|
||||
|
||||
const HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
@ -378,11 +372,9 @@ TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) {
|
||||
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
|
||||
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
|
||||
|
||||
const HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
|
@ -26,10 +26,7 @@ bool HloInputOutputAliasConfig::OutputHasAlias(
|
||||
|
||||
Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
|
||||
int64 param_number,
|
||||
const ShapeIndex& param_index,
|
||||
AliasKind kind) {
|
||||
TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias)
|
||||
<< kind;
|
||||
const ShapeIndex& param_index) {
|
||||
TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
|
||||
<< "Trying to set up alias at " << output_index.ToString()
|
||||
<< " which is an invalid index for shape "
|
||||
@ -44,8 +41,7 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
|
||||
param_number, param_index.ToString(), output_index.ToString(),
|
||||
alias_.element(output_index)->parameter_number,
|
||||
alias_.element(output_index)->parameter_index.ToString());
|
||||
(*alias_.mutable_element(output_index)) =
|
||||
Alias(kind, param_number, param_index);
|
||||
(*alias_.mutable_element(output_index)) = Alias(param_number, param_index);
|
||||
VLOG(4) << "Set up alias between output index " << output_index.ToString()
|
||||
<< " and parameter " << param_index << " at index "
|
||||
<< param_index.ToString();
|
||||
@ -58,16 +54,6 @@ HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
|
||||
[&](const ShapeIndex& index, const absl::optional<Alias>& data) {
|
||||
if (data) {
|
||||
HloInputOutputAliasProto::AliasEntryProto entry;
|
||||
switch (data->kind) {
|
||||
case AliasKind::kUserAlias:
|
||||
entry.set_kind(HloInputOutputAliasProto::USER_ALIAS);
|
||||
break;
|
||||
case AliasKind::kSystemAlias:
|
||||
entry.set_kind(HloInputOutputAliasProto::SYSTEM_ALIAS);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unknown alias kind " << data->kind;
|
||||
}
|
||||
for (int64 i : index) {
|
||||
entry.add_output_shape_index(i);
|
||||
}
|
||||
@ -91,14 +77,8 @@ StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
|
||||
int64 param_number = entry.parameter_number();
|
||||
ShapeIndex param_index(entry.parameter_shape_index().begin(),
|
||||
entry.parameter_shape_index().end());
|
||||
// Handle backward compatibility with existing protos, which only knew of
|
||||
// system aliases.
|
||||
AliasKind kind = AliasKind::kSystemAlias;
|
||||
if (entry.kind() == HloInputOutputAliasProto::USER_ALIAS) {
|
||||
kind = AliasKind::kUserAlias;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
result.SetUpAlias(output_index, param_number, param_index, kind));
|
||||
result.SetUpAlias(output_index, param_number, param_index));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -113,9 +93,9 @@ string HloInputOutputAliasConfig::ToString() const {
|
||||
|
||||
ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
|
||||
pieces.push_back(absl::StrFormat(
|
||||
" OutputIndex %s is aliased (kind=%s) with parameter %lld at %s:",
|
||||
output_index.ToString(), AliasKindToString(alias.kind),
|
||||
alias.parameter_number, alias.parameter_index.ToString()));
|
||||
" OutputIndex %s is aliased with parameter %lld at %s:",
|
||||
output_index.ToString(), alias.parameter_number,
|
||||
alias.parameter_index.ToString()));
|
||||
});
|
||||
return absl::StrJoin(pieces, "\n");
|
||||
}
|
||||
@ -134,20 +114,6 @@ string HloInputOutputAliasConfig::ToShortString() const {
|
||||
return absl::StrJoin(pieces, ", ");
|
||||
}
|
||||
|
||||
absl::optional<HloInputOutputAliasConfig::AliasKind>
|
||||
HloInputOutputAliasConfig::ParameterAliasKind(
|
||||
int64 param_number, const ShapeIndex& param_index) const {
|
||||
absl::optional<AliasKind> kind;
|
||||
alias_.ForEachElement(
|
||||
[&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
|
||||
if (alias && alias->parameter_number == param_number &&
|
||||
alias->parameter_index == param_index) {
|
||||
kind = alias->kind;
|
||||
}
|
||||
});
|
||||
return kind;
|
||||
}
|
||||
|
||||
absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
|
||||
int64 param_number, const ShapeIndex& param_index) const {
|
||||
absl::optional<ShapeIndex> output;
|
||||
|
@ -32,43 +32,19 @@ class HloModule;
|
||||
// parameter index in the entry computation.
|
||||
class HloInputOutputAliasConfig {
|
||||
public:
|
||||
// The kind of aliases which can be set. A kUserAlias is one setup at
|
||||
// compilation time by the user, and has to be respected. A kSystemAlias one
|
||||
// might be setup by the compiler, if it decides it is convenient to do so.
|
||||
enum AliasKind {
|
||||
kUserAlias,
|
||||
kSystemAlias,
|
||||
};
|
||||
|
||||
static std::string AliasKindToString(AliasKind kind) {
|
||||
switch (kind) {
|
||||
case kUserAlias:
|
||||
return "USER";
|
||||
case kSystemAlias:
|
||||
return "SYSTEM";
|
||||
}
|
||||
}
|
||||
|
||||
// Defines the alias information for a given output buffer. A given output
|
||||
// buffer shape index can refer only to one parameter+index.
|
||||
struct Alias {
|
||||
Alias(AliasKind kind, int64 parameter_number, ShapeIndex parameter_index)
|
||||
: kind(kind),
|
||||
parameter_number(parameter_number),
|
||||
Alias(int64 parameter_number, ShapeIndex parameter_index)
|
||||
: parameter_number(parameter_number),
|
||||
parameter_index(std::move(parameter_index)) {}
|
||||
|
||||
AliasKind kind;
|
||||
int64 parameter_number;
|
||||
ShapeIndex parameter_index;
|
||||
|
||||
std::string ToString() {
|
||||
if (kind == kUserAlias) {
|
||||
return absl::StrFormat("(%lld, %s)", parameter_number,
|
||||
parameter_index.ToString());
|
||||
}
|
||||
return absl::StrFormat("(%lld, %s, %s)", parameter_number,
|
||||
parameter_index.ToString(),
|
||||
AliasKindToString(kind));
|
||||
return absl::StrFormat("(%lld, %s)", parameter_number,
|
||||
parameter_index.ToString());
|
||||
}
|
||||
};
|
||||
|
||||
@ -82,19 +58,13 @@ class HloInputOutputAliasConfig {
|
||||
// Sets up alias config from `output_index` to `param_index` at
|
||||
// `param_number`.
|
||||
Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
|
||||
const ShapeIndex& param_index,
|
||||
AliasKind kind = AliasKind::kUserAlias);
|
||||
|
||||
// Returns the kind of alias for the given parameter number and parameter
|
||||
// index.
|
||||
absl::optional<AliasKind> ParameterAliasKind(
|
||||
int64 param_number, const ShapeIndex& param_index) const;
|
||||
const ShapeIndex& param_index);
|
||||
|
||||
// Returns true if the given parameter is aliased with one of the output
|
||||
// buffers.
|
||||
bool ParameterHasAlias(int64 param_number,
|
||||
const ShapeIndex& param_index) const {
|
||||
return ParameterAliasKind(param_number, param_index).has_value();
|
||||
return GetAliasedOutput(param_number, param_index).has_value();
|
||||
}
|
||||
|
||||
// Checks whether the provided output index has already been aliased.
|
||||
|
@ -86,8 +86,7 @@ ENTRY main {
|
||||
|
||||
TF_ASSERT_OK(config.SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/1,
|
||||
/*param_index=*/{},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{}));
|
||||
|
||||
expect_aliased(/*output_index=*/{0}, /*param_number=*/1,
|
||||
/*param_index=*/{}, config);
|
||||
@ -118,13 +117,11 @@ ENTRY main {
|
||||
|
||||
TF_ASSERT_OK(config.SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/0,
|
||||
/*param_index=*/{0},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{0}));
|
||||
|
||||
TF_ASSERT_OK(config.SetUpAlias(
|
||||
/*output_index=*/{1}, /*param_number=*/0,
|
||||
/*param_index=*/{1},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{1}));
|
||||
|
||||
expect_aliased(/*output_index=*/{0}, /*param_number=*/0,
|
||||
/*param_index=*/{0}, config);
|
||||
@ -157,13 +154,11 @@ ENTRY main {
|
||||
|
||||
TF_ASSERT_OK(config.SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/0,
|
||||
/*param_index=*/{},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{}));
|
||||
|
||||
TF_ASSERT_OK(config.SetUpAlias(
|
||||
/*output_index=*/{1}, /*param_number=*/0,
|
||||
/*param_index=*/{},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{}));
|
||||
|
||||
ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) {
|
||||
return ShapeUtil::ByteSizeOf(shape);
|
||||
@ -188,8 +183,7 @@ ENTRY main {
|
||||
|
||||
TF_ASSERT_OK(config.SetUpAlias(
|
||||
/*output_index=*/{1}, /*param_number=*/0,
|
||||
/*param_index=*/{},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{}));
|
||||
|
||||
ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) {
|
||||
return ShapeUtil::ByteSizeOf(shape);
|
||||
@ -214,13 +208,11 @@ ENTRY main {
|
||||
|
||||
TF_ASSERT_OK(config.SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/0,
|
||||
/*param_index=*/{},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{}));
|
||||
|
||||
ASSERT_IS_NOT_OK(config.SetUpAlias(
|
||||
/*output_index=*/{0}, /*param_number=*/1,
|
||||
/*param_index=*/{},
|
||||
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
|
||||
/*param_index=*/{}));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -237,8 +237,7 @@ TEST_F(HloLiveRangeTest, AliasedParameter) {
|
||||
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
// Set up alias of the first parameter.
|
||||
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
|
||||
{}, 0, {}, HloInputOutputAliasConfig::kUserAlias));
|
||||
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias({}, 0, {}));
|
||||
|
||||
HloSchedule schedule(module_.get());
|
||||
|
||||
|
@ -563,21 +563,8 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) {
|
||||
if (!ParseShapeIndex(¶m_idx)) {
|
||||
return false;
|
||||
}
|
||||
HloInputOutputAliasConfig::AliasKind alias_kind =
|
||||
HloInputOutputAliasConfig::kUserAlias;
|
||||
if (EatIfPresent(TokKind::kComma)) {
|
||||
std::string type;
|
||||
ParseName(&type);
|
||||
if (type == "SYSTEM") {
|
||||
alias_kind = HloInputOutputAliasConfig::kSystemAlias;
|
||||
} else if (type == "USER") {
|
||||
alias_kind = HloInputOutputAliasConfig::kUserAlias;
|
||||
} else {
|
||||
return TokenError("Unexpected aliasing kind; expected SYSTEM or USER");
|
||||
}
|
||||
}
|
||||
data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
|
||||
std::forward_as_tuple(alias_kind, param_num, param_idx));
|
||||
std::forward_as_tuple(param_num, param_idx));
|
||||
if (!ParseToken(TokKind::kRparen, errmsg)) {
|
||||
return false;
|
||||
}
|
||||
@ -627,9 +614,8 @@ bool HloParserImpl::ParseHloModule(HloModule* module) {
|
||||
if (aliasing_data) {
|
||||
HloInputOutputAliasConfig alias_config(module->result_shape());
|
||||
for (auto& p : *aliasing_data) {
|
||||
Status st =
|
||||
alias_config.SetUpAlias(p.first, p.second.parameter_number,
|
||||
p.second.parameter_index, p.second.kind);
|
||||
Status st = alias_config.SetUpAlias(p.first, p.second.parameter_number,
|
||||
p.second.parameter_index);
|
||||
if (!st.ok()) {
|
||||
return TokenError(st.error_message());
|
||||
}
|
||||
|
@ -2399,7 +2399,7 @@ ENTRY c2 {
|
||||
|
||||
TEST_F(HloParserTest, SimpleAliasing) {
|
||||
const string original = R"(
|
||||
HloModule Module, input_output_alias={ {0}: (0, {0}, USER), {1}: (0, {1}, USER) }
|
||||
HloModule Module, input_output_alias={ {0}: (0, {0}), {1}: (0, {1}) }
|
||||
|
||||
ENTRY entry {
|
||||
%p = (f32[], f32[]) parameter(0)
|
||||
@ -2539,22 +2539,6 @@ ENTRY entry {
|
||||
"expects integer");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, AliasingUnexpectedKind) {
|
||||
const string original = R"(
|
||||
HloModule Module, input_output_alias={ {0}: (0, {0}, UNKNOWN), {1}: (0, {1}, UNKNOWN) }
|
||||
|
||||
ENTRY entry {
|
||||
%p = (f32[], f32[]) parameter(0)
|
||||
%p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
|
||||
%p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
|
||||
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
|
||||
}
|
||||
)";
|
||||
ExpectHasSubstr(
|
||||
ParseAndReturnUnverifiedModule(original).status().error_message(),
|
||||
"Unexpected aliasing kind");
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, MultipleRoots) {
|
||||
const string original = R"(HloModule multiple_roots:
|
||||
ENTRY consts {
|
||||
|
@ -3247,10 +3247,8 @@ TEST_P(MemorySpaceAssignmentTest, InputOutputAlias) {
|
||||
TF_CHECK_OK(module->set_schedule(schedule));
|
||||
|
||||
// Make input {0} alias with output {0} and input {1} alias with output {1}.
|
||||
TF_CHECK_OK(module->input_output_alias_config().SetUpAlias(
|
||||
{0}, 0, {0}, HloInputOutputAliasConfig::AliasKind::kSystemAlias));
|
||||
TF_CHECK_OK(module->input_output_alias_config().SetUpAlias(
|
||||
{1}, 0, {1}, HloInputOutputAliasConfig::AliasKind::kSystemAlias));
|
||||
TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
|
||||
TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));
|
||||
|
||||
AssignMemorySpace(module.get());
|
||||
|
||||
|
@ -75,9 +75,8 @@ StatusOr<bool> OptimizeInputOutputBufferAlias::Build(
|
||||
const ShapeIndex& output_index = index;
|
||||
if (!alias_config->ParameterHasAlias(0, input_index) &&
|
||||
!alias_config->OutputHasAlias(output_index)) {
|
||||
TF_RETURN_IF_ERROR(alias_config->SetUpAlias(
|
||||
output_index, 0, input_index,
|
||||
HloInputOutputAliasConfig::AliasKind::kSystemAlias));
|
||||
TF_RETURN_IF_ERROR(
|
||||
alias_config->SetUpAlias(output_index, 0, input_index));
|
||||
}
|
||||
entry.used = true;
|
||||
break;
|
||||
|
@ -303,11 +303,8 @@ Status RebuildOutputAliases(
|
||||
[&](const xla::ShapeIndex& output_index,
|
||||
const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
|
||||
TF_RET_CHECK(alias.parameter_number < input_tuples.size());
|
||||
return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias
|
||||
? output_tuple->AliasBufferFrom(
|
||||
*input_tuples[alias.parameter_number],
|
||||
alias.parameter_index, output_index)
|
||||
: Status::OK();
|
||||
return output_tuple->AliasBufferFrom(*input_tuples[alias.parameter_number],
|
||||
alias.parameter_index, output_index);
|
||||
};
|
||||
return input_output_alias.ForEachAliasWithStatus(alias_function);
|
||||
}
|
||||
@ -332,17 +329,7 @@ xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
|
||||
for (int64 i = 0; i < input_tuples.size(); ++i) {
|
||||
auto alias_checker =
|
||||
[&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
|
||||
// Only the buffers which the caller explicitly marked as aliased
|
||||
// (kUserAlias), should create aliases.
|
||||
// The XLA compiler might create opportunistic aliases (kSystemAlias)
|
||||
// which need a different handling. With a system alias we know that XLA
|
||||
// is going to reuse a given input parameter buffer for a given output, so
|
||||
// unless it is known at call site that the input buffer has no more uses,
|
||||
// a copy needs to be made at call site. With user specified alias the
|
||||
// caller tells us that he expects a given output to land over the buffers
|
||||
// of a given parametter.
|
||||
if (input_output_alias.ParameterAliasKind(i, index) ==
|
||||
xla::HloInputOutputAliasConfig::AliasKind::kUserAlias) {
|
||||
if (input_output_alias.ParameterHasAlias(i, index)) {
|
||||
TF_RET_CHECK(!is_dynamic(i));
|
||||
return true;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user