[Resubmit] If an input-output pair is configured to be must-alias(off by default), they must be aliased at runtime.
PiperOrigin-RevId: 325503193 Change-Id: Ida4e46531052c40eebce5f0dff4c50914cc1f3f4
This commit is contained in:
parent
247e9bd050
commit
5296ad4ffd
@ -524,7 +524,7 @@ TEST(CompileGraphToXlaHlo, Resources) {
|
|||||||
ASSERT_TRUE(status_or_hlo_module.ok());
|
ASSERT_TRUE(status_or_hlo_module.ok());
|
||||||
|
|
||||||
constexpr char expected_hlo_module_string[] =
|
constexpr char expected_hlo_module_string[] =
|
||||||
R"(HloModule main.4, input_output_alias={ {0}: 1 }
|
R"(HloModule main.4, input_output_alias={ {0}: (1, {}, may_alias) }
|
||||||
|
|
||||||
ENTRY %main.4 (Arg_0.1: f32[2], Arg_1.2: f32[2]) -> (f32[2]) {
|
ENTRY %main.4 (Arg_0.1: f32[2], Arg_1.2: f32[2]) -> (f32[2]) {
|
||||||
%Arg_1.2 = f32[2]{0} parameter(1)
|
%Arg_1.2 = f32[2]{0} parameter(1)
|
||||||
|
|||||||
@ -446,7 +446,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id,
|
|||||||
alias.param_index.ToString().c_str());
|
alias.param_index.ToString().c_str());
|
||||||
}
|
}
|
||||||
TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
|
TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
|
||||||
alias.param_index));
|
alias.param_index, alias.kind));
|
||||||
}
|
}
|
||||||
*module->mutable_input_output_alias() = config.ToProto();
|
*module->mutable_input_output_alias() = config.ToProto();
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
|||||||
@ -32,6 +32,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
|
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
@ -349,12 +350,16 @@ class XlaBuilder {
|
|||||||
// not available until the computation is built, and eventual error in the
|
// not available until the computation is built, and eventual error in the
|
||||||
// arguments of this API will be detected only at computation Build() time.
|
// arguments of this API will be detected only at computation Build() time.
|
||||||
//
|
//
|
||||||
// Note: Aliasing API is 'may-alias' and only donated buffer at runtime will
|
// Note: Except when 'must-alias' is true, alias is assumed to be 'may-alias'
|
||||||
// be aliased with output. If a buffer is not donated at runtime, a copy will
|
// and only donated buffer at runtime will be aliased with output. If a buffer
|
||||||
// be inserted by XLA to prevent buffer clobbering.
|
// is not donated at runtime, a copy will be inserted by XLA to prevent buffer
|
||||||
|
// clobbering.
|
||||||
void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
|
void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
|
||||||
const ShapeIndex& param_index) {
|
const ShapeIndex& param_index,
|
||||||
input_output_aliases_.push_back({output_index, param_number, param_index});
|
HloInputOutputAliasConfig::AliasKind kind =
|
||||||
|
HloInputOutputAliasConfig::AliasKind::kMayAlias) {
|
||||||
|
input_output_aliases_.push_back(
|
||||||
|
{output_index, param_number, param_index, kind});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Describes an input/output alias as inserted by the SetUpAlias() API.
|
// Describes an input/output alias as inserted by the SetUpAlias() API.
|
||||||
@ -365,6 +370,8 @@ class XlaBuilder {
|
|||||||
int64 param_number;
|
int64 param_number;
|
||||||
// Specifies the index of the aliased buffer in the parameter
|
// Specifies the index of the aliased buffer in the parameter
|
||||||
ShapeIndex param_index;
|
ShapeIndex param_index;
|
||||||
|
// Specifies if the alias is a must alias or may alias.
|
||||||
|
HloInputOutputAliasConfig::AliasKind kind;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Looks up the HloInstruction and sets the frontend attribute "attribute" to
|
// Looks up the HloInstruction and sets the frontend attribute "attribute" to
|
||||||
|
|||||||
@ -247,6 +247,12 @@ StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
|
|||||||
ExecutionInput& input = arguments[alias->parameter_number];
|
ExecutionInput& input = arguments[alias->parameter_number];
|
||||||
MaybeOwningDeviceMemory* maybe_owning_memory =
|
MaybeOwningDeviceMemory* maybe_owning_memory =
|
||||||
input.MutableBuffer(alias->parameter_index);
|
input.MutableBuffer(alias->parameter_index);
|
||||||
|
if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"An input was configured to be must-alias at "
|
||||||
|
"compile time but not donated at runtime: %s",
|
||||||
|
alias->ToString());
|
||||||
|
}
|
||||||
if (absl::optional<se::OwningDeviceMemory> owning =
|
if (absl::optional<se::OwningDeviceMemory> owning =
|
||||||
maybe_owning_memory->Release()) {
|
maybe_owning_memory->Release()) {
|
||||||
// If the caller passes the ownership of the device memory, reuse it
|
// If the caller passes the ownership of the device memory, reuse it
|
||||||
|
|||||||
@ -480,6 +480,12 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
|
|||||||
ExecutionInput& input = arguments[alias->parameter_number];
|
ExecutionInput& input = arguments[alias->parameter_number];
|
||||||
MaybeOwningDeviceMemory* maybe_owning_memory =
|
MaybeOwningDeviceMemory* maybe_owning_memory =
|
||||||
input.MutableBuffer(alias->parameter_index);
|
input.MutableBuffer(alias->parameter_index);
|
||||||
|
if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"An input was configured to be must-alias at "
|
||||||
|
"compile time but not donated at runtime: %s",
|
||||||
|
alias->ToString());
|
||||||
|
}
|
||||||
if (absl::optional<se::OwningDeviceMemory> owning =
|
if (absl::optional<se::OwningDeviceMemory> owning =
|
||||||
maybe_owning_memory->Release()) {
|
maybe_owning_memory->Release()) {
|
||||||
// If the caller passes the ownership of the device memory, reuse it
|
// If the caller passes the ownership of the device memory, reuse it
|
||||||
|
|||||||
@ -283,6 +283,16 @@ message HloScheduleProto {
|
|||||||
map<int64, InstructionSequence> sequences = 1;
|
map<int64, InstructionSequence> sequences = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum Kind {
|
||||||
|
// Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3
|
||||||
|
// behavior and missing has_*() APIs.
|
||||||
|
UNDEFINED_ALIAS = 0;
|
||||||
|
// The buffers may or may not alias at runtime.
|
||||||
|
MAY_ALIAS = 1;
|
||||||
|
// The buffers must alias at runtime.
|
||||||
|
MUST_ALIAS = 2;
|
||||||
|
}
|
||||||
|
|
||||||
message HloInputOutputAliasProto {
|
message HloInputOutputAliasProto {
|
||||||
// The following proto describes a pair of aliased an input
|
// The following proto describes a pair of aliased an input
|
||||||
// (described by parameter number and a ShapeIndex of the parameter)
|
// (described by parameter number and a ShapeIndex of the parameter)
|
||||||
@ -304,8 +314,8 @@ message HloInputOutputAliasProto {
|
|||||||
int64 parameter_number = 2;
|
int64 parameter_number = 2;
|
||||||
// ShapeIndex of the parameter instruction.
|
// ShapeIndex of the parameter instruction.
|
||||||
repeated int64 parameter_shape_index = 3;
|
repeated int64 parameter_shape_index = 3;
|
||||||
reserved 4;
|
// The kind of alias to be setup.
|
||||||
reserved "kind";
|
Kind kind = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
repeated AliasEntryProto entries = 1;
|
repeated AliasEntryProto entries = 1;
|
||||||
|
|||||||
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -24,9 +25,10 @@ bool HloInputOutputAliasConfig::OutputHasAlias(
|
|||||||
return alias_.element(output_index).has_value();
|
return alias_.element(output_index).has_value();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
|
Status HloInputOutputAliasConfig::SetUpAlias(
|
||||||
int64 param_number,
|
const ShapeIndex& output_index, int64 param_number,
|
||||||
const ShapeIndex& param_index) {
|
const ShapeIndex& param_index,
|
||||||
|
HloInputOutputAliasConfig::AliasKind must_alias) {
|
||||||
TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
|
TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
|
||||||
<< "Trying to set up alias at " << output_index.ToString()
|
<< "Trying to set up alias at " << output_index.ToString()
|
||||||
<< " which is an invalid index for shape "
|
<< " which is an invalid index for shape "
|
||||||
@ -41,7 +43,8 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
|
|||||||
param_number, param_index.ToString(), output_index.ToString(),
|
param_number, param_index.ToString(), output_index.ToString(),
|
||||||
alias_.element(output_index)->parameter_number,
|
alias_.element(output_index)->parameter_number,
|
||||||
alias_.element(output_index)->parameter_index.ToString());
|
alias_.element(output_index)->parameter_index.ToString());
|
||||||
(*alias_.mutable_element(output_index)) = Alias(param_number, param_index);
|
(*alias_.mutable_element(output_index)) =
|
||||||
|
Alias(param_number, param_index, must_alias);
|
||||||
VLOG(4) << "Set up alias between output index " << output_index.ToString()
|
VLOG(4) << "Set up alias between output index " << output_index.ToString()
|
||||||
<< " and parameter " << param_index << " at index "
|
<< " and parameter " << param_index << " at index "
|
||||||
<< param_index.ToString();
|
<< param_index.ToString();
|
||||||
@ -61,6 +64,11 @@ HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
|
|||||||
for (int64 i : data->parameter_index) {
|
for (int64 i : data->parameter_index) {
|
||||||
entry.add_parameter_shape_index(i);
|
entry.add_parameter_shape_index(i);
|
||||||
}
|
}
|
||||||
|
if (data->must_alias()) {
|
||||||
|
entry.set_kind(Kind::MUST_ALIAS);
|
||||||
|
} else {
|
||||||
|
entry.set_kind(Kind::MAY_ALIAS);
|
||||||
|
}
|
||||||
result.add_entries()->Swap(&entry);
|
result.add_entries()->Swap(&entry);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -77,8 +85,9 @@ StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
|
|||||||
int64 param_number = entry.parameter_number();
|
int64 param_number = entry.parameter_number();
|
||||||
ShapeIndex param_index(entry.parameter_shape_index().begin(),
|
ShapeIndex param_index(entry.parameter_shape_index().begin(),
|
||||||
entry.parameter_shape_index().end());
|
entry.parameter_shape_index().end());
|
||||||
|
AliasKind kind = entry.kind() == Kind::MAY_ALIAS ? kMayAlias : kMustAlias;
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
result.SetUpAlias(output_index, param_number, param_index));
|
result.SetUpAlias(output_index, param_number, param_index, kind));
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -93,9 +102,9 @@ string HloInputOutputAliasConfig::ToString() const {
|
|||||||
|
|
||||||
ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
|
ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
|
||||||
pieces.push_back(absl::StrFormat(
|
pieces.push_back(absl::StrFormat(
|
||||||
" OutputIndex %s is aliased with parameter %lld at %s:",
|
" OutputIndex %s is %saliased with parameter %lld at %s:",
|
||||||
output_index.ToString(), alias.parameter_number,
|
output_index.ToString(), alias.kind == kMustAlias ? "must-" : "may-",
|
||||||
alias.parameter_index.ToString()));
|
alias.parameter_number, alias.parameter_index.ToString()));
|
||||||
});
|
});
|
||||||
return absl::StrJoin(pieces, "\n");
|
return absl::StrJoin(pieces, "\n");
|
||||||
}
|
}
|
||||||
@ -112,6 +121,19 @@ string HloInputOutputAliasConfig::ToShortString() const {
|
|||||||
return absl::StrJoin(pieces, ", ");
|
return absl::StrJoin(pieces, ", ");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool HloInputOutputAliasConfig::ParameterMustAlias(
|
||||||
|
int64 param_number, const ShapeIndex& param_index) const {
|
||||||
|
bool result = false;
|
||||||
|
alias_.ForEachElement(
|
||||||
|
[&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
|
||||||
|
if (alias && alias->parameter_number == param_number &&
|
||||||
|
alias->parameter_index == param_index && alias->must_alias()) {
|
||||||
|
result = true;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
|
absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
|
||||||
int64 param_number, const ShapeIndex& param_index) const {
|
int64 param_number, const ShapeIndex& param_index) const {
|
||||||
absl::optional<ShapeIndex> output;
|
absl::optional<ShapeIndex> output;
|
||||||
|
|||||||
@ -32,22 +32,32 @@ class HloModule;
|
|||||||
// parameter index in the entry computation.
|
// parameter index in the entry computation.
|
||||||
class HloInputOutputAliasConfig {
|
class HloInputOutputAliasConfig {
|
||||||
public:
|
public:
|
||||||
|
// The kind of aliases which can be set. A kMayAlias is one setup at
|
||||||
|
// compilation time by the user, and has to be respected. A kMustAlias one
|
||||||
|
// might be setup by the compiler, if it decides it is convenient to do so.
|
||||||
|
enum AliasKind {
|
||||||
|
kMayAlias,
|
||||||
|
kMustAlias,
|
||||||
|
};
|
||||||
// Defines the alias information for a given output buffer. A given output
|
// Defines the alias information for a given output buffer. A given output
|
||||||
// buffer shape index can refer only to one parameter+index.
|
// buffer shape index can refer only to one parameter+index.
|
||||||
struct Alias {
|
struct Alias {
|
||||||
Alias(int64 parameter_number, ShapeIndex parameter_index)
|
Alias(int64 parameter_number, ShapeIndex parameter_index,
|
||||||
|
AliasKind kind = kMayAlias)
|
||||||
: parameter_number(parameter_number),
|
: parameter_number(parameter_number),
|
||||||
parameter_index(std::move(parameter_index)) {}
|
parameter_index(std::move(parameter_index)),
|
||||||
|
kind(kind) {}
|
||||||
|
|
||||||
int64 parameter_number;
|
int64 parameter_number;
|
||||||
ShapeIndex parameter_index;
|
ShapeIndex parameter_index;
|
||||||
|
AliasKind kind;
|
||||||
|
|
||||||
|
bool must_alias() const { return kind == kMustAlias; }
|
||||||
|
|
||||||
std::string ToString() {
|
std::string ToString() {
|
||||||
if (parameter_index.empty()) {
|
return absl::StrFormat("(%lld, %s, %s)", parameter_number,
|
||||||
return absl::StrCat(parameter_number);
|
parameter_index.ToString(),
|
||||||
}
|
kind == kMustAlias ? "must_alias" : "may_alias");
|
||||||
return absl::StrFormat("(%lld, %s)", parameter_number,
|
|
||||||
parameter_index.ToString());
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -61,7 +71,8 @@ class HloInputOutputAliasConfig {
|
|||||||
// Sets up alias config from `output_index` to `param_index` at
|
// Sets up alias config from `output_index` to `param_index` at
|
||||||
// `param_number`.
|
// `param_number`.
|
||||||
Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
|
Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
|
||||||
const ShapeIndex& param_index);
|
const ShapeIndex& param_index,
|
||||||
|
AliasKind must_alias = kMayAlias);
|
||||||
|
|
||||||
// Returns true if the given parameter is aliased with one of the output
|
// Returns true if the given parameter is aliased with one of the output
|
||||||
// buffers.
|
// buffers.
|
||||||
@ -92,6 +103,11 @@ class HloInputOutputAliasConfig {
|
|||||||
absl::optional<Alias> GetAliasedParameter(
|
absl::optional<Alias> GetAliasedParameter(
|
||||||
const ShapeIndex& output_index) const;
|
const ShapeIndex& output_index) const;
|
||||||
|
|
||||||
|
// Returns if the parameter at the given parameter number and parameter
|
||||||
|
// index must-alias with an output.
|
||||||
|
bool ParameterMustAlias(int64 param_number,
|
||||||
|
const ShapeIndex& param_index) const;
|
||||||
|
|
||||||
using AliasFn =
|
using AliasFn =
|
||||||
std::function<void(const ShapeIndex& output_index, const Alias&)>;
|
std::function<void(const ShapeIndex& output_index, const Alias&)>;
|
||||||
|
|
||||||
|
|||||||
@ -552,14 +552,6 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (lexer_.GetKind() != TokKind::kLparen) {
|
|
||||||
// Short form: "{0}: 0", output index "{}" is assumed.
|
|
||||||
int64 param_num;
|
|
||||||
ParseInt64(¶m_num);
|
|
||||||
data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
|
|
||||||
std::forward_as_tuple(param_num, ShapeIndex{}));
|
|
||||||
} else {
|
|
||||||
// Long form: "{0}: (0, {0})", output index is explicitly specified.
|
|
||||||
if (!ParseToken(TokKind::kLparen, errmsg)) {
|
if (!ParseToken(TokKind::kLparen, errmsg)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -572,12 +564,26 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) {
|
|||||||
if (!ParseShapeIndex(¶m_idx)) {
|
if (!ParseShapeIndex(¶m_idx)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloInputOutputAliasConfig::AliasKind alias_kind =
|
||||||
|
HloInputOutputAliasConfig::kMayAlias;
|
||||||
|
if (EatIfPresent(TokKind::kComma)) {
|
||||||
|
std::string type;
|
||||||
|
ParseName(&type);
|
||||||
|
if (type == "must-alias") {
|
||||||
|
alias_kind = HloInputOutputAliasConfig::kMustAlias;
|
||||||
|
} else if (type == "may-alias") {
|
||||||
|
alias_kind = HloInputOutputAliasConfig::kMayAlias;
|
||||||
|
} else {
|
||||||
|
return TokenError("Unexpected aliasing kind; expected SYSTEM or USER");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
|
data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
|
||||||
std::forward_as_tuple(param_num, param_idx));
|
std::forward_as_tuple(param_num, param_idx, alias_kind));
|
||||||
if (!ParseToken(TokKind::kRparen, errmsg)) {
|
if (!ParseToken(TokKind::kRparen, errmsg)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (!EatIfPresent(TokKind::kComma)) {
|
if (!EatIfPresent(TokKind::kComma)) {
|
||||||
break;
|
break;
|
||||||
@ -624,8 +630,9 @@ bool HloParserImpl::ParseHloModule(HloModule* module) {
|
|||||||
if (aliasing_data) {
|
if (aliasing_data) {
|
||||||
HloInputOutputAliasConfig alias_config(module->result_shape());
|
HloInputOutputAliasConfig alias_config(module->result_shape());
|
||||||
for (auto& p : *aliasing_data) {
|
for (auto& p : *aliasing_data) {
|
||||||
Status st = alias_config.SetUpAlias(p.first, p.second.parameter_number,
|
Status st =
|
||||||
p.second.parameter_index);
|
alias_config.SetUpAlias(p.first, p.second.parameter_number,
|
||||||
|
p.second.parameter_index, p.second.kind);
|
||||||
if (!st.ok()) {
|
if (!st.ok()) {
|
||||||
return TokenError(st.error_message());
|
return TokenError(st.error_message());
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2399,7 +2399,7 @@ ENTRY c2 {
|
|||||||
|
|
||||||
TEST_F(HloParserTest, SimpleAliasing) {
|
TEST_F(HloParserTest, SimpleAliasing) {
|
||||||
const string original = R"(
|
const string original = R"(
|
||||||
HloModule Module, input_output_alias={ {0}: (0, {0}), {1}: (0, {1}) }
|
HloModule Module, input_output_alias={ {0}: (0, {0}, must-alias), {1}: (0, {1}) }
|
||||||
|
|
||||||
ENTRY entry {
|
ENTRY entry {
|
||||||
%p = (f32[], f32[]) parameter(0)
|
%p = (f32[], f32[]) parameter(0)
|
||||||
@ -2413,42 +2413,13 @@ ENTRY entry {
|
|||||||
std::unique_ptr<HloModule> parsed_module = module.ConsumeValueOrDie();
|
std::unique_ptr<HloModule> parsed_module = module.ConsumeValueOrDie();
|
||||||
EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {0}),
|
EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {0}),
|
||||||
ShapeIndex{0});
|
ShapeIndex{0});
|
||||||
|
|
||||||
|
EXPECT_TRUE(
|
||||||
|
parsed_module->input_output_alias_config().ParameterMustAlias(0, {0}));
|
||||||
EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {1}),
|
EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {1}),
|
||||||
ShapeIndex{1});
|
ShapeIndex{1});
|
||||||
}
|
EXPECT_FALSE(
|
||||||
|
parsed_module->input_output_alias_config().ParameterMustAlias(0, {1}));
|
||||||
TEST_F(HloParserTest, SimpleAliasingShortForm) {
|
|
||||||
const string original = R"(
|
|
||||||
HloModule Module, input_output_alias={ {0}: 0, {1}: 1 }
|
|
||||||
|
|
||||||
ENTRY entry {
|
|
||||||
%p0 = f32[] parameter(0)
|
|
||||||
%p1 = f32[] parameter(1)
|
|
||||||
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
auto module = ParseAndReturnVerifiedModule(original);
|
|
||||||
TF_ASSERT_OK(module.status());
|
|
||||||
std::unique_ptr<HloModule> parsed_module = module.ConsumeValueOrDie();
|
|
||||||
EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {}),
|
|
||||||
ShapeIndex{0});
|
|
||||||
EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(1, {}),
|
|
||||||
ShapeIndex{1});
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(HloParserTest, SimpleAliasingShortFormError) {
|
|
||||||
const string original = R"(
|
|
||||||
HloModule Module, input_output_alias={ {0}: A, {1}: 1 }
|
|
||||||
|
|
||||||
ENTRY entry {
|
|
||||||
%p0 = f32[] parameter(0)
|
|
||||||
%p1 = f32[] parameter(1)
|
|
||||||
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
ExpectHasSubstr(
|
|
||||||
ParseAndReturnUnverifiedModule(original).status().error_message(),
|
|
||||||
"expects integer");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloParserTest, NestedAliasing) {
|
TEST_F(HloParserTest, NestedAliasing) {
|
||||||
|
|||||||
@ -61,7 +61,7 @@ class BufferDonationTest : public HloTestBase {
|
|||||||
absl::Span<Literal const> argument_literals,
|
absl::Span<Literal const> argument_literals,
|
||||||
absl::Span<bool const> donate_arguments,
|
absl::Span<bool const> donate_arguments,
|
||||||
absl::Span<bool const> expected_runtime_aliasing,
|
absl::Span<bool const> expected_runtime_aliasing,
|
||||||
const Literal& expected) {
|
const Literal& expected, std::string expected_failure = "") {
|
||||||
// Create a copy of the output shape because the HLO module is std::moved
|
// Create a copy of the output shape because the HLO module is std::moved
|
||||||
// into the compiler and may be deallocated.
|
// into the compiler and may be deallocated.
|
||||||
const Shape output_shape = hlo_module->result_shape();
|
const Shape output_shape = hlo_module->result_shape();
|
||||||
@ -123,10 +123,19 @@ class BufferDonationTest : public HloTestBase {
|
|||||||
ExecutionInput(std::move(owned_buffers), argument_literal.shape()));
|
ExecutionInput(std::move(owned_buffers), argument_literal.shape()));
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(
|
StatusOr<ExecutionOutput> output_status =
|
||||||
ExecutionOutput output,
|
|
||||||
executable->ExecuteAsyncOnStream(&service_run_options, std::move(args),
|
executable->ExecuteAsyncOnStream(&service_run_options, std::move(args),
|
||||||
/*hlo_execution_profile=*/nullptr));
|
/*hlo_execution_profile=*/nullptr);
|
||||||
|
if (!expected_failure.empty()) {
|
||||||
|
ASSERT_FALSE(output_status.ok());
|
||||||
|
ASSERT_TRUE(absl::StrContains(output_status.status().error_message(),
|
||||||
|
expected_failure))
|
||||||
|
<< "got: \n"
|
||||||
|
<< output_status.status().error_message() << " \nvs want\n"
|
||||||
|
<< expected_failure;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
ExecutionOutput output = output_status.ConsumeValueOrDie();
|
||||||
|
|
||||||
se::DeviceMemoryBase result_root_buffer = output.Result().root_buffer();
|
se::DeviceMemoryBase result_root_buffer = output.Result().root_buffer();
|
||||||
LOG(INFO) << "result allocation = " << result_root_buffer.opaque()
|
LOG(INFO) << "result allocation = " << result_root_buffer.opaque()
|
||||||
@ -303,5 +312,37 @@ ENTRY entry {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(BufferDonationTest, TestMustAliasNotDonated) {
|
||||||
|
HloModuleConfig config;
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<VerifiedHloModule>> module =
|
||||||
|
ParseAndReturnVerifiedModule(R"(
|
||||||
|
HloModule module
|
||||||
|
|
||||||
|
ENTRY entry {
|
||||||
|
a = f32[] parameter(0)
|
||||||
|
b = f32[] parameter(1)
|
||||||
|
ROOT out = (f32[], f32[]) tuple(a, b)
|
||||||
|
}
|
||||||
|
)",
|
||||||
|
config);
|
||||||
|
|
||||||
|
TF_ASSERT_OK(module->get()->input_output_alias_config().SetUpAlias(
|
||||||
|
{0}, 0, {}, HloInputOutputAliasConfig::kMustAlias));
|
||||||
|
|
||||||
|
std::vector<Literal> args;
|
||||||
|
args.push_back(LiteralUtil::CreateR0<float>(0.1));
|
||||||
|
args.push_back(LiteralUtil::CreateR0<float>(0.2));
|
||||||
|
Literal expected = LiteralUtil::MakeTupleFromSlices(
|
||||||
|
{LiteralUtil::CreateR0<float>(0.1), LiteralUtil::CreateR0<float>(0.2)});
|
||||||
|
|
||||||
|
#ifndef XLA_TEST_BACKEND_INTERPRETER
|
||||||
|
RunAndCheck(std::move(*module), args,
|
||||||
|
/*donate_arguments=*/{false, false}, {true, false}, expected,
|
||||||
|
"An input was configured to be must-alias at "
|
||||||
|
"compile time but not donated at runtime:");
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
@ -62,6 +62,24 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse(
|
|||||||
<< " host_shape = " << ShapeUtil::HumanStringWithLayout(host_shape);
|
<< " host_shape = " << ShapeUtil::HumanStringWithLayout(host_shape);
|
||||||
Shape device_shape = HostShapeToDeviceShape(host_shape);
|
Shape device_shape = HostShapeToDeviceShape(host_shape);
|
||||||
|
|
||||||
|
TF_RETURN_IF_ERROR(alias_config.ForEachAliasWithStatus(
|
||||||
|
[&](const ShapeIndex& output_index,
|
||||||
|
absl::optional<HloInputOutputAliasConfig::Alias> alias) {
|
||||||
|
if (alias && alias->must_alias()) {
|
||||||
|
VLOG(1) << alias->ToString();
|
||||||
|
const MaybeOwningDeviceMemory& original_input =
|
||||||
|
(*arguments)[alias->parameter_number].Buffers().element(
|
||||||
|
alias->parameter_index);
|
||||||
|
if (!original_input.HasOwnership()) {
|
||||||
|
return InvalidArgument(
|
||||||
|
"An input was configured to be must-alias at "
|
||||||
|
"compile time but not donated at runtime: %s",
|
||||||
|
alias->ToString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}));
|
||||||
|
|
||||||
if (VLOG_IS_ON(3)) {
|
if (VLOG_IS_ON(3)) {
|
||||||
VLOG(3) << "AllocateOutputMemoryWithInputReuse, device = " << device_ordinal
|
VLOG(3) << "AllocateOutputMemoryWithInputReuse, device = " << device_ordinal
|
||||||
<< " host_shape = " << ShapeUtil::HumanStringWithLayout(host_shape);
|
<< " host_shape = " << ShapeUtil::HumanStringWithLayout(host_shape);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user