[XLA] Unify aliasing types

SYSTEM/USER alias distinction is not actually used, and knowing it at compile time does not bring any advantages, as we check actual aliasing at runtime in any case.

PiperOrigin-RevId: 320079893
Change-Id: I726cfe9dae0256904778a3bc3e501566aa026f9f
This commit is contained in:
George Karpenkov 2020-07-07 16:02:40 -07:00 committed by TensorFlower Gardener
parent 66668ec0ca
commit 0103bdb3cd
19 changed files with 75 additions and 235 deletions

View File

@ -244,9 +244,7 @@ static bool MustAliasOutput(
if (input_output_alias.shape().tuple_shapes_size() == 0) {
return false;
}
return input_output_alias.OutputHasAlias(output_index) &&
input_output_alias.GetAliasedParameter(output_index).value().kind ==
xla::HloInputOutputAliasConfig::kUserAlias;
return input_output_alias.OutputHasAlias(output_index);
}
// Returns an aliased tensor if it exists, nullptr otherwise.

View File

@ -424,9 +424,8 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id,
alias.param_number,
alias.param_index.ToString().c_str());
}
TF_RETURN_IF_ERROR(config.SetUpAlias(
alias.output_index, alias.param_number, alias.param_index,
HloInputOutputAliasConfig::AliasKind::kUserAlias));
TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
alias.param_index));
}
*module->mutable_input_output_alias() = config.ToProto();
return Status::OK();

View File

@ -492,8 +492,7 @@ TEST_F(BufferAssignmentTest, AliasedParamCanBeReused) {
auto module = CreateNewVerifiedModule();
module->AddEntryComputation(builder.Build());
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias(
{}, 0, {}, HloInputOutputAliasConfig::kUserAlias));
TF_ASSERT_OK(module->input_output_alias_config().SetUpAlias({}, 0, {}));
auto buffers = RunBufferAssignment(module.get());

View File

@ -1586,11 +1586,9 @@ TEST_F(CopyInsertionTest, CrossingParameters) {
builder.AddInstruction(HloInstruction::CreateTuple({gte1, gte0}));
module->AddEntryComputation(builder.Build());
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 4);
@ -1621,11 +1619,9 @@ TEST_F(CopyInsertionTest, ParametersAliasing) {
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
module->AddEntryComputation(builder.Build());
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 0);
@ -1689,8 +1685,7 @@ TEST_F(CopyInsertionTest, ParameterWithPartialAliasing) {
builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
module->AddEntryComputation(builder.Build());
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
InsertCopies(module.get());
EXPECT_THAT(module->entry_computation()->root_instruction(),
@ -1731,8 +1726,7 @@ TEST_F(CopyInsertionTest, ParameterAndParallelOpsWithPartialAliasing) {
builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
module->AddEntryComputation(builder.Build());
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 0);
@ -1773,8 +1767,7 @@ TEST_F(CopyInsertionTest, ParameterAndOpsWithPartialAliasing) {
builder.AddInstruction(HloInstruction::CreateTuple({add, negate1}));
module->AddEntryComputation(builder.Build());
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 0);
@ -2505,11 +2498,11 @@ ENTRY entry_computation {
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1},
/*param_number=*/0,
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{}));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{3},
/*param_number=*/1,
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{}));
InsertCopies(module.get());
@ -2532,7 +2525,7 @@ ENTRY Entry {
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1},
/*param_number=*/0,
/*param_index=*/{}, HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{}));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 1);
}

View File

@ -256,17 +256,15 @@ StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
se::DeviceMemoryBase argument_buffer = owning->Release();
*maybe_owning_memory = argument_buffer;
result_buffer = argument_buffer;
if (alias->kind == HloInputOutputAliasConfig::kUserAlias) {
// This is a user alias, so a must alias. The caller is giving us the
// input buffer, but in case of error of the execute call, we should
// not be releasing it as it contains valid data (for example, it is a
// parameter which the user wants us to alias, in a gradient update
// computation). So we store the index into the result in the aliased
// vactor, which will be fed to the ExecutionOutput, which will be
// using the indices to drop the addresses from its own
// ScopedShapedBuffer result, if the ExecutionOutput is not committed.
result.AddAliasedIndex(index);
}
// The caller is giving us the
// input buffer, but in case of error of the execute call, we should
// not be releasing it as it contains valid data (for example, it is a
// parameter which the user wants us to alias, in a gradient update
// computation). So we store the index into the result in the aliased
// vactor, which will be fed to the ExecutionOutput, which will be
// using the indices to drop the addresses from its own
// ScopedShapedBuffer result, if the ExecutionOutput is not committed.
result.AddAliasedIndex(index);
} else {
VLOG(3) << "Using copy-protection: aliasing is specified, but the "
"buffer is not donated; allocating a fresh buffer";

View File

@ -37,12 +37,10 @@ StatusOr<bool> AliasPassthroughParams::Run(HloModule* module) {
<< " in module " << module->name()
<< " is passed-through to root tuple element " << i << ": "
<< root->shape().ToString();
// Use must-alias semantics (kUserAlias) for pass-through params.
TF_RETURN_IF_ERROR(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{i},
/*param_number=*/root->operand(i)->parameter_number(),
/*param_index=*/{},
HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{}));
used_params.insert(root->operand(i)->parameter_number());
changed = true;
}

View File

@ -38,12 +38,8 @@ TEST_F(AliasPassthroughParamsTest, AliasPassThroughParams) {
EXPECT_TRUE(AliasPassthroughParams().Run(module.get()).ValueOrDie());
const auto& alias_config = module->input_output_alias_config();
EXPECT_EQ(0, alias_config.GetAliasedParameter({0}).value().parameter_number);
EXPECT_EQ(xla::HloInputOutputAliasConfig::kUserAlias,
alias_config.GetAliasedParameter({0}).value().kind);
EXPECT_FALSE(alias_config.OutputHasAlias({1}));
EXPECT_EQ(1, alias_config.GetAliasedParameter({2}).value().parameter_number);
EXPECT_EQ(xla::HloInputOutputAliasConfig::kUserAlias,
alias_config.GetAliasedParameter({2}).value().kind);
}
TEST_F(AliasPassthroughParamsTest, DoNotAliasPassThroughParamsMoreThanOnce) {

View File

@ -499,17 +499,15 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
se::DeviceMemoryBase argument_buffer = owning->Release();
*maybe_owning_memory = argument_buffer;
result_buffer = argument_buffer;
if (alias->kind == HloInputOutputAliasConfig::kUserAlias) {
// This is a user alias, so a must alias. The caller is giving us the
// input buffer, but in case of error from the execute call, we should
// not be releasing it as it contains valid data (for example, it is a
// parameter which the user wants us to alias, in a gradient update
// computation). So we store the index into the result in the aliased
// vector, which will be fed to the ExecutionOutput, which will use
// the indices to drop the addresses from its own ScopedShapedBuffer
// result, if the ExecutionOutput is not committed.
result.AddAliasedIndex(index);
}
// The caller is giving us the
// input buffer, but in case of error from the execute call, we should
// not be releasing it as it contains valid data (for example, it is a
// parameter which the user wants us to alias, in a gradient update
// computation). So we store the index into the result in the aliased
// vector, which will be fed to the ExecutionOutput, which will use
// the indices to drop the addresses from its own ScopedShapedBuffer
// result, if the ExecutionOutput is not committed.
result.AddAliasedIndex(index);
} else if (src_hlo->opcode() != HloOpcode::kParameter) {
// The guard is above is not to insert copy-protection when aliasing
// pass-through params, as we do not need to write into the output

View File

@ -284,18 +284,6 @@ message HloScheduleProto {
}
message HloInputOutputAliasProto {
enum Kind {
// Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3
// behavior and missing has_*() APIs.
UNDEFINED_ALIAS = 0;
// An alias setup by the user as must alias. A use setting USER_ALIAS is
// expecting the designed output to be dropped over the given input
// parameter number+index.
USER_ALIAS = 1;
// An alias setup by the compiler as part of its optimizations.
SYSTEM_ALIAS = 2;
}
// The following proto describes a pair of aliased an input
// (described by parameter number and a ShapeIndex of the parameter)
// and an output (described by a ShapeIndex of the root
@ -316,8 +304,8 @@ message HloInputOutputAliasProto {
int64 parameter_number = 2;
// ShapeIndex of the parameter instruction.
repeated int64 parameter_shape_index = 3;
// The kind of alias to be setup.
Kind kind = 4;
reserved 4;
reserved "kind";
}
repeated AliasEntryProto entries = 1;

View File

@ -241,16 +241,13 @@ TEST_F(HloAliasAnalysisTest, ParametersWithAliasing) {
SCOPED_TRACE(module_->ToString());
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
// Cannot alias an output twice.
ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
const HloAliasAnalysis& analysis = RunAnalysis();
@ -287,16 +284,13 @@ TEST_F(HloAliasAnalysisTest, ParametersWithCrossAliasing) {
SCOPED_TRACE(module_->ToString());
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{1}));
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{0}));
// Cannot alias an output twice.
ASSERT_IS_NOT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
const HloAliasAnalysis& analysis = RunAnalysis();
@ -378,11 +372,9 @@ TEST_F(HloAliasAnalysisTest, InputOutputAliasingWithWhile) {
SCOPED_TRACE(module_->ToString());
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{0}, /*param_number=*/0, /*param_index=*/{0}));
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{1}));
const HloAliasAnalysis& analysis = RunAnalysis();

View File

@ -26,10 +26,7 @@ bool HloInputOutputAliasConfig::OutputHasAlias(
Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
int64 param_number,
const ShapeIndex& param_index,
AliasKind kind) {
TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias)
<< kind;
const ShapeIndex& param_index) {
TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
<< "Trying to set up alias at " << output_index.ToString()
<< " which is an invalid index for shape "
@ -44,8 +41,7 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
param_number, param_index.ToString(), output_index.ToString(),
alias_.element(output_index)->parameter_number,
alias_.element(output_index)->parameter_index.ToString());
(*alias_.mutable_element(output_index)) =
Alias(kind, param_number, param_index);
(*alias_.mutable_element(output_index)) = Alias(param_number, param_index);
VLOG(4) << "Set up alias between output index " << output_index.ToString()
<< " and parameter " << param_index << " at index "
<< param_index.ToString();
@ -58,16 +54,6 @@ HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
[&](const ShapeIndex& index, const absl::optional<Alias>& data) {
if (data) {
HloInputOutputAliasProto::AliasEntryProto entry;
switch (data->kind) {
case AliasKind::kUserAlias:
entry.set_kind(HloInputOutputAliasProto::USER_ALIAS);
break;
case AliasKind::kSystemAlias:
entry.set_kind(HloInputOutputAliasProto::SYSTEM_ALIAS);
break;
default:
LOG(FATAL) << "Unknown alias kind " << data->kind;
}
for (int64 i : index) {
entry.add_output_shape_index(i);
}
@ -91,14 +77,8 @@ StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
int64 param_number = entry.parameter_number();
ShapeIndex param_index(entry.parameter_shape_index().begin(),
entry.parameter_shape_index().end());
// Handle backward compatibility with existing protos, which only knew of
// system aliases.
AliasKind kind = AliasKind::kSystemAlias;
if (entry.kind() == HloInputOutputAliasProto::USER_ALIAS) {
kind = AliasKind::kUserAlias;
}
TF_RETURN_IF_ERROR(
result.SetUpAlias(output_index, param_number, param_index, kind));
result.SetUpAlias(output_index, param_number, param_index));
}
return result;
}
@ -113,9 +93,9 @@ string HloInputOutputAliasConfig::ToString() const {
ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
pieces.push_back(absl::StrFormat(
" OutputIndex %s is aliased (kind=%s) with parameter %lld at %s:",
output_index.ToString(), AliasKindToString(alias.kind),
alias.parameter_number, alias.parameter_index.ToString()));
" OutputIndex %s is aliased with parameter %lld at %s:",
output_index.ToString(), alias.parameter_number,
alias.parameter_index.ToString()));
});
return absl::StrJoin(pieces, "\n");
}
@ -134,20 +114,6 @@ string HloInputOutputAliasConfig::ToShortString() const {
return absl::StrJoin(pieces, ", ");
}
absl::optional<HloInputOutputAliasConfig::AliasKind>
HloInputOutputAliasConfig::ParameterAliasKind(
int64 param_number, const ShapeIndex& param_index) const {
absl::optional<AliasKind> kind;
alias_.ForEachElement(
[&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
if (alias && alias->parameter_number == param_number &&
alias->parameter_index == param_index) {
kind = alias->kind;
}
});
return kind;
}
absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
int64 param_number, const ShapeIndex& param_index) const {
absl::optional<ShapeIndex> output;

View File

@ -32,43 +32,19 @@ class HloModule;
// parameter index in the entry computation.
class HloInputOutputAliasConfig {
public:
// The kind of aliases which can be set. A kUserAlias is one setup at
// compilation time by the user, and has to be respected. A kSystemAlias one
// might be setup by the compiler, if it decides it is convenient to do so.
enum AliasKind {
kUserAlias,
kSystemAlias,
};
static std::string AliasKindToString(AliasKind kind) {
switch (kind) {
case kUserAlias:
return "USER";
case kSystemAlias:
return "SYSTEM";
}
}
// Defines the alias information for a given output buffer. A given output
// buffer shape index can refer only to one parameter+index.
struct Alias {
Alias(AliasKind kind, int64 parameter_number, ShapeIndex parameter_index)
: kind(kind),
parameter_number(parameter_number),
Alias(int64 parameter_number, ShapeIndex parameter_index)
: parameter_number(parameter_number),
parameter_index(std::move(parameter_index)) {}
AliasKind kind;
int64 parameter_number;
ShapeIndex parameter_index;
std::string ToString() {
if (kind == kUserAlias) {
return absl::StrFormat("(%lld, %s)", parameter_number,
parameter_index.ToString());
}
return absl::StrFormat("(%lld, %s, %s)", parameter_number,
parameter_index.ToString(),
AliasKindToString(kind));
return absl::StrFormat("(%lld, %s)", parameter_number,
parameter_index.ToString());
}
};
@ -82,19 +58,13 @@ class HloInputOutputAliasConfig {
// Sets up alias config from `output_index` to `param_index` at
// `param_number`.
Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
const ShapeIndex& param_index,
AliasKind kind = AliasKind::kUserAlias);
// Returns the kind of alias for the given parameter number and parameter
// index.
absl::optional<AliasKind> ParameterAliasKind(
int64 param_number, const ShapeIndex& param_index) const;
const ShapeIndex& param_index);
// Returns true if the given parameter is aliased with one of the output
// buffers.
bool ParameterHasAlias(int64 param_number,
const ShapeIndex& param_index) const {
return ParameterAliasKind(param_number, param_index).has_value();
return GetAliasedOutput(param_number, param_index).has_value();
}
// Checks whether the provided output index has already been aliased.

View File

@ -86,8 +86,7 @@ ENTRY main {
TF_ASSERT_OK(config.SetUpAlias(
/*output_index=*/{0}, /*param_number=*/1,
/*param_index=*/{},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{}));
expect_aliased(/*output_index=*/{0}, /*param_number=*/1,
/*param_index=*/{}, config);
@ -118,13 +117,11 @@ ENTRY main {
TF_ASSERT_OK(config.SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0,
/*param_index=*/{0},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{0}));
TF_ASSERT_OK(config.SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0,
/*param_index=*/{1},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{1}));
expect_aliased(/*output_index=*/{0}, /*param_number=*/0,
/*param_index=*/{0}, config);
@ -157,13 +154,11 @@ ENTRY main {
TF_ASSERT_OK(config.SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0,
/*param_index=*/{},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{}));
TF_ASSERT_OK(config.SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0,
/*param_index=*/{},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{}));
ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) {
return ShapeUtil::ByteSizeOf(shape);
@ -188,8 +183,7 @@ ENTRY main {
TF_ASSERT_OK(config.SetUpAlias(
/*output_index=*/{1}, /*param_number=*/0,
/*param_index=*/{},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{}));
ASSERT_IS_NOT_OK(config.Verify(*module, [](const Shape& shape) {
return ShapeUtil::ByteSizeOf(shape);
@ -214,13 +208,11 @@ ENTRY main {
TF_ASSERT_OK(config.SetUpAlias(
/*output_index=*/{0}, /*param_number=*/0,
/*param_index=*/{},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{}));
ASSERT_IS_NOT_OK(config.SetUpAlias(
/*output_index=*/{0}, /*param_number=*/1,
/*param_index=*/{},
/*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
/*param_index=*/{}));
}
} // namespace
} // namespace xla

View File

@ -237,8 +237,7 @@ TEST_F(HloLiveRangeTest, AliasedParameter) {
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, mul, paramY));
module_->AddEntryComputation(builder.Build());
// Set up alias of the first parameter.
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias(
{}, 0, {}, HloInputOutputAliasConfig::kUserAlias));
TF_ASSERT_OK(module_->input_output_alias_config().SetUpAlias({}, 0, {}));
HloSchedule schedule(module_.get());

View File

@ -563,21 +563,8 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) {
if (!ParseShapeIndex(&param_idx)) {
return false;
}
HloInputOutputAliasConfig::AliasKind alias_kind =
HloInputOutputAliasConfig::kUserAlias;
if (EatIfPresent(TokKind::kComma)) {
std::string type;
ParseName(&type);
if (type == "SYSTEM") {
alias_kind = HloInputOutputAliasConfig::kSystemAlias;
} else if (type == "USER") {
alias_kind = HloInputOutputAliasConfig::kUserAlias;
} else {
return TokenError("Unexpected aliasing kind; expected SYSTEM or USER");
}
}
data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
std::forward_as_tuple(alias_kind, param_num, param_idx));
std::forward_as_tuple(param_num, param_idx));
if (!ParseToken(TokKind::kRparen, errmsg)) {
return false;
}
@ -627,9 +614,8 @@ bool HloParserImpl::ParseHloModule(HloModule* module) {
if (aliasing_data) {
HloInputOutputAliasConfig alias_config(module->result_shape());
for (auto& p : *aliasing_data) {
Status st =
alias_config.SetUpAlias(p.first, p.second.parameter_number,
p.second.parameter_index, p.second.kind);
Status st = alias_config.SetUpAlias(p.first, p.second.parameter_number,
p.second.parameter_index);
if (!st.ok()) {
return TokenError(st.error_message());
}

View File

@ -2399,7 +2399,7 @@ ENTRY c2 {
TEST_F(HloParserTest, SimpleAliasing) {
const string original = R"(
HloModule Module, input_output_alias={ {0}: (0, {0}, USER), {1}: (0, {1}, USER) }
HloModule Module, input_output_alias={ {0}: (0, {0}), {1}: (0, {1}) }
ENTRY entry {
%p = (f32[], f32[]) parameter(0)
@ -2539,22 +2539,6 @@ ENTRY entry {
"expects integer");
}
TEST_F(HloParserTest, AliasingUnexpectedKind) {
const string original = R"(
HloModule Module, input_output_alias={ {0}: (0, {0}, UNKNOWN), {1}: (0, {1}, UNKNOWN) }
ENTRY entry {
%p = (f32[], f32[]) parameter(0)
%p0 = f32[] get-tuple-element((f32[], f32[]) %p), index=0
%p1 = f32[] get-tuple-element((f32[], f32[]) %p), index=1
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
}
)";
ExpectHasSubstr(
ParseAndReturnUnverifiedModule(original).status().error_message(),
"Unexpected aliasing kind");
}
TEST_F(HloParserTest, MultipleRoots) {
const string original = R"(HloModule multiple_roots:
ENTRY consts {

View File

@ -3247,10 +3247,8 @@ TEST_P(MemorySpaceAssignmentTest, InputOutputAlias) {
TF_CHECK_OK(module->set_schedule(schedule));
// Make input {0} alias with output {0} and input {1} alias with output {1}.
TF_CHECK_OK(module->input_output_alias_config().SetUpAlias(
{0}, 0, {0}, HloInputOutputAliasConfig::AliasKind::kSystemAlias));
TF_CHECK_OK(module->input_output_alias_config().SetUpAlias(
{1}, 0, {1}, HloInputOutputAliasConfig::AliasKind::kSystemAlias));
TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));
AssignMemorySpace(module.get());

View File

@ -75,9 +75,8 @@ StatusOr<bool> OptimizeInputOutputBufferAlias::Build(
const ShapeIndex& output_index = index;
if (!alias_config->ParameterHasAlias(0, input_index) &&
!alias_config->OutputHasAlias(output_index)) {
TF_RETURN_IF_ERROR(alias_config->SetUpAlias(
output_index, 0, input_index,
HloInputOutputAliasConfig::AliasKind::kSystemAlias));
TF_RETURN_IF_ERROR(
alias_config->SetUpAlias(output_index, 0, input_index));
}
entry.used = true;
break;

View File

@ -303,11 +303,8 @@ Status RebuildOutputAliases(
[&](const xla::ShapeIndex& output_index,
const xla::HloInputOutputAliasConfig::Alias& alias) -> Status {
TF_RET_CHECK(alias.parameter_number < input_tuples.size());
return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias
? output_tuple->AliasBufferFrom(
*input_tuples[alias.parameter_number],
alias.parameter_index, output_index)
: Status::OK();
return output_tuple->AliasBufferFrom(*input_tuples[alias.parameter_number],
alias.parameter_index, output_index);
};
return input_output_alias.ForEachAliasWithStatus(alias_function);
}
@ -332,17 +329,7 @@ xla::StatusOr<std::vector<xla::ExecutionInput>> GetArgumentsBuffers(
for (int64 i = 0; i < input_tuples.size(); ++i) {
auto alias_checker =
[&](const xla::ShapeIndex& index) -> xla::StatusOr<bool> {
// Only the buffers which the caller explicitly marked as aliased
// (kUserAlias), should create aliases.
// The XLA compiler might create opportunistic aliases (kSystemAlias)
// which need a different handling. With a system alias we know that XLA
// is going to reuse a given input parameter buffer for a given output, so
// unless it is known at call site that the input buffer has no more uses,
// a copy needs to be made at call site. With user specified alias the
// caller tells us that he expects a given output to land over the buffers
// of a given parametter.
if (input_output_alias.ParameterAliasKind(i, index) ==
xla::HloInputOutputAliasConfig::AliasKind::kUserAlias) {
if (input_output_alias.ParameterHasAlias(i, index)) {
TF_RET_CHECK(!is_dynamic(i));
return true;
}