[Resubmit] If an input-output pair is configured to be must-alias(off by default), they must be aliased at runtime.
PiperOrigin-RevId: 325503193 Change-Id: Ida4e46531052c40eebce5f0dff4c50914cc1f3f4
This commit is contained in:
parent
247e9bd050
commit
5296ad4ffd
@ -524,7 +524,7 @@ TEST(CompileGraphToXlaHlo, Resources) {
|
||||
ASSERT_TRUE(status_or_hlo_module.ok());
|
||||
|
||||
constexpr char expected_hlo_module_string[] =
|
||||
R"(HloModule main.4, input_output_alias={ {0}: 1 }
|
||||
R"(HloModule main.4, input_output_alias={ {0}: (1, {}, may_alias) }
|
||||
|
||||
ENTRY %main.4 (Arg_0.1: f32[2], Arg_1.2: f32[2]) -> (f32[2]) {
|
||||
%Arg_1.2 = f32[2]{0} parameter(1)
|
||||
|
@ -446,7 +446,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id,
|
||||
alias.param_index.ToString().c_str());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(config.SetUpAlias(alias.output_index, alias.param_number,
|
||||
alias.param_index));
|
||||
alias.param_index, alias.kind));
|
||||
}
|
||||
*module->mutable_input_output_alias() = config.ToProto();
|
||||
return Status::OK();
|
||||
|
@ -32,6 +32,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -349,12 +350,16 @@ class XlaBuilder {
|
||||
// not available until the computation is built, and eventual error in the
|
||||
// arguments of this API will be detected only at computation Build() time.
|
||||
//
|
||||
// Note: Aliasing API is 'may-alias' and only donated buffer at runtime will
|
||||
// be aliased with output. If a buffer is not donated at runtime, a copy will
|
||||
// be inserted by XLA to prevent buffer clobbering.
|
||||
// Note: Except when 'must-alias' is true, alias is assumed to be 'may-alias'
|
||||
// and only donated buffer at runtime will be aliased with output. If a buffer
|
||||
// is not donated at runtime, a copy will be inserted by XLA to prevent buffer
|
||||
// clobbering.
|
||||
void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
|
||||
const ShapeIndex& param_index) {
|
||||
input_output_aliases_.push_back({output_index, param_number, param_index});
|
||||
const ShapeIndex& param_index,
|
||||
HloInputOutputAliasConfig::AliasKind kind =
|
||||
HloInputOutputAliasConfig::AliasKind::kMayAlias) {
|
||||
input_output_aliases_.push_back(
|
||||
{output_index, param_number, param_index, kind});
|
||||
}
|
||||
|
||||
// Describes an input/output alias as inserted by the SetUpAlias() API.
|
||||
@ -365,6 +370,8 @@ class XlaBuilder {
|
||||
int64 param_number;
|
||||
// Specifies the index of the aliased buffer in the parameter
|
||||
ShapeIndex param_index;
|
||||
// Specifies if the alias is a must alias or may alias.
|
||||
HloInputOutputAliasConfig::AliasKind kind;
|
||||
};
|
||||
|
||||
// Looks up the HloInstruction and sets the frontend attribute "attribute" to
|
||||
|
@ -247,6 +247,12 @@ StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(
|
||||
ExecutionInput& input = arguments[alias->parameter_number];
|
||||
MaybeOwningDeviceMemory* maybe_owning_memory =
|
||||
input.MutableBuffer(alias->parameter_index);
|
||||
if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) {
|
||||
return InvalidArgument(
|
||||
"An input was configured to be must-alias at "
|
||||
"compile time but not donated at runtime: %s",
|
||||
alias->ToString());
|
||||
}
|
||||
if (absl::optional<se::OwningDeviceMemory> owning =
|
||||
maybe_owning_memory->Release()) {
|
||||
// If the caller passes the ownership of the device memory, reuse it
|
||||
|
@ -480,6 +480,12 @@ StatusOr<ExecutionOutput> GpuExecutable::ExecuteAsyncOnStream(
|
||||
ExecutionInput& input = arguments[alias->parameter_number];
|
||||
MaybeOwningDeviceMemory* maybe_owning_memory =
|
||||
input.MutableBuffer(alias->parameter_index);
|
||||
if (alias->must_alias() && !maybe_owning_memory->HasOwnership()) {
|
||||
return InvalidArgument(
|
||||
"An input was configured to be must-alias at "
|
||||
"compile time but not donated at runtime: %s",
|
||||
alias->ToString());
|
||||
}
|
||||
if (absl::optional<se::OwningDeviceMemory> owning =
|
||||
maybe_owning_memory->Release()) {
|
||||
// If the caller passes the ownership of the device memory, reuse it
|
||||
|
@ -283,6 +283,16 @@ message HloScheduleProto {
|
||||
map<int64, InstructionSequence> sequences = 1;
|
||||
}
|
||||
|
||||
enum Kind {
|
||||
// Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3
|
||||
// behavior and missing has_*() APIs.
|
||||
UNDEFINED_ALIAS = 0;
|
||||
// The buffers may or may not alias at runtime.
|
||||
MAY_ALIAS = 1;
|
||||
// The buffers must alias at runtime.
|
||||
MUST_ALIAS = 2;
|
||||
}
|
||||
|
||||
message HloInputOutputAliasProto {
|
||||
// The following proto describes a pair of aliased an input
|
||||
// (described by parameter number and a ShapeIndex of the parameter)
|
||||
@ -304,8 +314,8 @@ message HloInputOutputAliasProto {
|
||||
int64 parameter_number = 2;
|
||||
// ShapeIndex of the parameter instruction.
|
||||
repeated int64 parameter_shape_index = 3;
|
||||
reserved 4;
|
||||
reserved "kind";
|
||||
// The kind of alias to be setup.
|
||||
Kind kind = 4;
|
||||
}
|
||||
|
||||
repeated AliasEntryProto entries = 1;
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
|
||||
namespace xla {
|
||||
@ -24,9 +25,10 @@ bool HloInputOutputAliasConfig::OutputHasAlias(
|
||||
return alias_.element(output_index).has_value();
|
||||
}
|
||||
|
||||
Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
|
||||
int64 param_number,
|
||||
const ShapeIndex& param_index) {
|
||||
Status HloInputOutputAliasConfig::SetUpAlias(
|
||||
const ShapeIndex& output_index, int64 param_number,
|
||||
const ShapeIndex& param_index,
|
||||
HloInputOutputAliasConfig::AliasKind must_alias) {
|
||||
TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
|
||||
<< "Trying to set up alias at " << output_index.ToString()
|
||||
<< " which is an invalid index for shape "
|
||||
@ -41,7 +43,8 @@ Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
|
||||
param_number, param_index.ToString(), output_index.ToString(),
|
||||
alias_.element(output_index)->parameter_number,
|
||||
alias_.element(output_index)->parameter_index.ToString());
|
||||
(*alias_.mutable_element(output_index)) = Alias(param_number, param_index);
|
||||
(*alias_.mutable_element(output_index)) =
|
||||
Alias(param_number, param_index, must_alias);
|
||||
VLOG(4) << "Set up alias between output index " << output_index.ToString()
|
||||
<< " and parameter " << param_index << " at index "
|
||||
<< param_index.ToString();
|
||||
@ -61,6 +64,11 @@ HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
|
||||
for (int64 i : data->parameter_index) {
|
||||
entry.add_parameter_shape_index(i);
|
||||
}
|
||||
if (data->must_alias()) {
|
||||
entry.set_kind(Kind::MUST_ALIAS);
|
||||
} else {
|
||||
entry.set_kind(Kind::MAY_ALIAS);
|
||||
}
|
||||
result.add_entries()->Swap(&entry);
|
||||
}
|
||||
});
|
||||
@ -77,8 +85,9 @@ StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
|
||||
int64 param_number = entry.parameter_number();
|
||||
ShapeIndex param_index(entry.parameter_shape_index().begin(),
|
||||
entry.parameter_shape_index().end());
|
||||
AliasKind kind = entry.kind() == Kind::MAY_ALIAS ? kMayAlias : kMustAlias;
|
||||
TF_RETURN_IF_ERROR(
|
||||
result.SetUpAlias(output_index, param_number, param_index));
|
||||
result.SetUpAlias(output_index, param_number, param_index, kind));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -93,9 +102,9 @@ string HloInputOutputAliasConfig::ToString() const {
|
||||
|
||||
ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
|
||||
pieces.push_back(absl::StrFormat(
|
||||
" OutputIndex %s is aliased with parameter %lld at %s:",
|
||||
output_index.ToString(), alias.parameter_number,
|
||||
alias.parameter_index.ToString()));
|
||||
" OutputIndex %s is %saliased with parameter %lld at %s:",
|
||||
output_index.ToString(), alias.kind == kMustAlias ? "must-" : "may-",
|
||||
alias.parameter_number, alias.parameter_index.ToString()));
|
||||
});
|
||||
return absl::StrJoin(pieces, "\n");
|
||||
}
|
||||
@ -112,6 +121,19 @@ string HloInputOutputAliasConfig::ToShortString() const {
|
||||
return absl::StrJoin(pieces, ", ");
|
||||
}
|
||||
|
||||
bool HloInputOutputAliasConfig::ParameterMustAlias(
|
||||
int64 param_number, const ShapeIndex& param_index) const {
|
||||
bool result = false;
|
||||
alias_.ForEachElement(
|
||||
[&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
|
||||
if (alias && alias->parameter_number == param_number &&
|
||||
alias->parameter_index == param_index && alias->must_alias()) {
|
||||
result = true;
|
||||
}
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
|
||||
int64 param_number, const ShapeIndex& param_index) const {
|
||||
absl::optional<ShapeIndex> output;
|
||||
|
@ -32,22 +32,32 @@ class HloModule;
|
||||
// parameter index in the entry computation.
|
||||
class HloInputOutputAliasConfig {
|
||||
public:
|
||||
// The kind of aliases which can be set. A kMayAlias is one setup at
|
||||
// compilation time by the user, and has to be respected. A kMustAlias one
|
||||
// might be setup by the compiler, if it decides it is convenient to do so.
|
||||
enum AliasKind {
|
||||
kMayAlias,
|
||||
kMustAlias,
|
||||
};
|
||||
// Defines the alias information for a given output buffer. A given output
|
||||
// buffer shape index can refer only to one parameter+index.
|
||||
struct Alias {
|
||||
Alias(int64 parameter_number, ShapeIndex parameter_index)
|
||||
Alias(int64 parameter_number, ShapeIndex parameter_index,
|
||||
AliasKind kind = kMayAlias)
|
||||
: parameter_number(parameter_number),
|
||||
parameter_index(std::move(parameter_index)) {}
|
||||
parameter_index(std::move(parameter_index)),
|
||||
kind(kind) {}
|
||||
|
||||
int64 parameter_number;
|
||||
ShapeIndex parameter_index;
|
||||
AliasKind kind;
|
||||
|
||||
bool must_alias() const { return kind == kMustAlias; }
|
||||
|
||||
std::string ToString() {
|
||||
if (parameter_index.empty()) {
|
||||
return absl::StrCat(parameter_number);
|
||||
}
|
||||
return absl::StrFormat("(%lld, %s)", parameter_number,
|
||||
parameter_index.ToString());
|
||||
return absl::StrFormat("(%lld, %s, %s)", parameter_number,
|
||||
parameter_index.ToString(),
|
||||
kind == kMustAlias ? "must_alias" : "may_alias");
|
||||
}
|
||||
};
|
||||
|
||||
@ -61,7 +71,8 @@ class HloInputOutputAliasConfig {
|
||||
// Sets up alias config from `output_index` to `param_index` at
|
||||
// `param_number`.
|
||||
Status SetUpAlias(const ShapeIndex& output_index, int64 param_number,
|
||||
const ShapeIndex& param_index);
|
||||
const ShapeIndex& param_index,
|
||||
AliasKind must_alias = kMayAlias);
|
||||
|
||||
// Returns true if the given parameter is aliased with one of the output
|
||||
// buffers.
|
||||
@ -92,6 +103,11 @@ class HloInputOutputAliasConfig {
|
||||
absl::optional<Alias> GetAliasedParameter(
|
||||
const ShapeIndex& output_index) const;
|
||||
|
||||
// Returns if the parameter at the given parameter number and parameter
|
||||
// index must-alias with an output.
|
||||
bool ParameterMustAlias(int64 param_number,
|
||||
const ShapeIndex& param_index) const;
|
||||
|
||||
using AliasFn =
|
||||
std::function<void(const ShapeIndex& output_index, const Alias&)>;
|
||||
|
||||
|
@ -552,33 +552,39 @@ bool HloParserImpl::ParseAliasing(AliasingData* data) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (lexer_.GetKind() != TokKind::kLparen) {
|
||||
// Short form: "{0}: 0", output index "{}" is assumed.
|
||||
int64 param_num;
|
||||
ParseInt64(¶m_num);
|
||||
data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
|
||||
std::forward_as_tuple(param_num, ShapeIndex{}));
|
||||
} else {
|
||||
// Long form: "{0}: (0, {0})", output index is explicitly specified.
|
||||
if (!ParseToken(TokKind::kLparen, errmsg)) {
|
||||
return false;
|
||||
}
|
||||
int64 param_num;
|
||||
ParseInt64(¶m_num);
|
||||
if (!ParseToken(TokKind::kComma, errmsg)) {
|
||||
return false;
|
||||
}
|
||||
ShapeIndex param_idx;
|
||||
if (!ParseShapeIndex(¶m_idx)) {
|
||||
return false;
|
||||
}
|
||||
data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
|
||||
std::forward_as_tuple(param_num, param_idx));
|
||||
if (!ParseToken(TokKind::kRparen, errmsg)) {
|
||||
return false;
|
||||
if (!ParseToken(TokKind::kLparen, errmsg)) {
|
||||
return false;
|
||||
}
|
||||
int64 param_num;
|
||||
ParseInt64(¶m_num);
|
||||
if (!ParseToken(TokKind::kComma, errmsg)) {
|
||||
return false;
|
||||
}
|
||||
ShapeIndex param_idx;
|
||||
if (!ParseShapeIndex(¶m_idx)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
HloInputOutputAliasConfig::AliasKind alias_kind =
|
||||
HloInputOutputAliasConfig::kMayAlias;
|
||||
if (EatIfPresent(TokKind::kComma)) {
|
||||
std::string type;
|
||||
ParseName(&type);
|
||||
if (type == "must-alias") {
|
||||
alias_kind = HloInputOutputAliasConfig::kMustAlias;
|
||||
} else if (type == "may-alias") {
|
||||
alias_kind = HloInputOutputAliasConfig::kMayAlias;
|
||||
} else {
|
||||
return TokenError("Unexpected aliasing kind; expected SYSTEM or USER");
|
||||
}
|
||||
}
|
||||
|
||||
data->emplace(std::piecewise_construct, std::forward_as_tuple(out),
|
||||
std::forward_as_tuple(param_num, param_idx, alias_kind));
|
||||
if (!ParseToken(TokKind::kRparen, errmsg)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!EatIfPresent(TokKind::kComma)) {
|
||||
break;
|
||||
}
|
||||
@ -624,8 +630,9 @@ bool HloParserImpl::ParseHloModule(HloModule* module) {
|
||||
if (aliasing_data) {
|
||||
HloInputOutputAliasConfig alias_config(module->result_shape());
|
||||
for (auto& p : *aliasing_data) {
|
||||
Status st = alias_config.SetUpAlias(p.first, p.second.parameter_number,
|
||||
p.second.parameter_index);
|
||||
Status st =
|
||||
alias_config.SetUpAlias(p.first, p.second.parameter_number,
|
||||
p.second.parameter_index, p.second.kind);
|
||||
if (!st.ok()) {
|
||||
return TokenError(st.error_message());
|
||||
}
|
||||
|
@ -2399,7 +2399,7 @@ ENTRY c2 {
|
||||
|
||||
TEST_F(HloParserTest, SimpleAliasing) {
|
||||
const string original = R"(
|
||||
HloModule Module, input_output_alias={ {0}: (0, {0}), {1}: (0, {1}) }
|
||||
HloModule Module, input_output_alias={ {0}: (0, {0}, must-alias), {1}: (0, {1}) }
|
||||
|
||||
ENTRY entry {
|
||||
%p = (f32[], f32[]) parameter(0)
|
||||
@ -2413,42 +2413,13 @@ ENTRY entry {
|
||||
std::unique_ptr<HloModule> parsed_module = module.ConsumeValueOrDie();
|
||||
EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {0}),
|
||||
ShapeIndex{0});
|
||||
|
||||
EXPECT_TRUE(
|
||||
parsed_module->input_output_alias_config().ParameterMustAlias(0, {0}));
|
||||
EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {1}),
|
||||
ShapeIndex{1});
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, SimpleAliasingShortForm) {
|
||||
const string original = R"(
|
||||
HloModule Module, input_output_alias={ {0}: 0, {1}: 1 }
|
||||
|
||||
ENTRY entry {
|
||||
%p0 = f32[] parameter(0)
|
||||
%p1 = f32[] parameter(1)
|
||||
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnVerifiedModule(original);
|
||||
TF_ASSERT_OK(module.status());
|
||||
std::unique_ptr<HloModule> parsed_module = module.ConsumeValueOrDie();
|
||||
EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(0, {}),
|
||||
ShapeIndex{0});
|
||||
EXPECT_EQ(parsed_module->input_output_alias_config().GetAliasedOutput(1, {}),
|
||||
ShapeIndex{1});
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, SimpleAliasingShortFormError) {
|
||||
const string original = R"(
|
||||
HloModule Module, input_output_alias={ {0}: A, {1}: 1 }
|
||||
|
||||
ENTRY entry {
|
||||
%p0 = f32[] parameter(0)
|
||||
%p1 = f32[] parameter(1)
|
||||
ROOT %out = (f32[], f32[]) tuple(%p0, %p1)
|
||||
}
|
||||
)";
|
||||
ExpectHasSubstr(
|
||||
ParseAndReturnUnverifiedModule(original).status().error_message(),
|
||||
"expects integer");
|
||||
EXPECT_FALSE(
|
||||
parsed_module->input_output_alias_config().ParameterMustAlias(0, {1}));
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, NestedAliasing) {
|
||||
|
@ -61,7 +61,7 @@ class BufferDonationTest : public HloTestBase {
|
||||
absl::Span<Literal const> argument_literals,
|
||||
absl::Span<bool const> donate_arguments,
|
||||
absl::Span<bool const> expected_runtime_aliasing,
|
||||
const Literal& expected) {
|
||||
const Literal& expected, std::string expected_failure = "") {
|
||||
// Create a copy of the output shape because the HLO module is std::moved
|
||||
// into the compiler and may be deallocated.
|
||||
const Shape output_shape = hlo_module->result_shape();
|
||||
@ -123,10 +123,19 @@ class BufferDonationTest : public HloTestBase {
|
||||
ExecutionInput(std::move(owned_buffers), argument_literal.shape()));
|
||||
}
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
ExecutionOutput output,
|
||||
StatusOr<ExecutionOutput> output_status =
|
||||
executable->ExecuteAsyncOnStream(&service_run_options, std::move(args),
|
||||
/*hlo_execution_profile=*/nullptr));
|
||||
/*hlo_execution_profile=*/nullptr);
|
||||
if (!expected_failure.empty()) {
|
||||
ASSERT_FALSE(output_status.ok());
|
||||
ASSERT_TRUE(absl::StrContains(output_status.status().error_message(),
|
||||
expected_failure))
|
||||
<< "got: \n"
|
||||
<< output_status.status().error_message() << " \nvs want\n"
|
||||
<< expected_failure;
|
||||
return;
|
||||
}
|
||||
ExecutionOutput output = output_status.ConsumeValueOrDie();
|
||||
|
||||
se::DeviceMemoryBase result_root_buffer = output.Result().root_buffer();
|
||||
LOG(INFO) << "result allocation = " << result_root_buffer.opaque()
|
||||
@ -303,5 +312,37 @@ ENTRY entry {
|
||||
#endif
|
||||
}
|
||||
|
||||
TEST_F(BufferDonationTest, TestMustAliasNotDonated) {
|
||||
HloModuleConfig config;
|
||||
|
||||
StatusOr<std::unique_ptr<VerifiedHloModule>> module =
|
||||
ParseAndReturnVerifiedModule(R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY entry {
|
||||
a = f32[] parameter(0)
|
||||
b = f32[] parameter(1)
|
||||
ROOT out = (f32[], f32[]) tuple(a, b)
|
||||
}
|
||||
)",
|
||||
config);
|
||||
|
||||
TF_ASSERT_OK(module->get()->input_output_alias_config().SetUpAlias(
|
||||
{0}, 0, {}, HloInputOutputAliasConfig::kMustAlias));
|
||||
|
||||
std::vector<Literal> args;
|
||||
args.push_back(LiteralUtil::CreateR0<float>(0.1));
|
||||
args.push_back(LiteralUtil::CreateR0<float>(0.2));
|
||||
Literal expected = LiteralUtil::MakeTupleFromSlices(
|
||||
{LiteralUtil::CreateR0<float>(0.1), LiteralUtil::CreateR0<float>(0.2)});
|
||||
|
||||
#ifndef XLA_TEST_BACKEND_INTERPRETER
|
||||
RunAndCheck(std::move(*module), args,
|
||||
/*donate_arguments=*/{false, false}, {true, false}, expected,
|
||||
"An input was configured to be must-alias at "
|
||||
"compile time but not donated at runtime:");
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -62,6 +62,24 @@ TpuExecutableInterface::AllocateOutputMemoryWithInputReuse(
|
||||
<< " host_shape = " << ShapeUtil::HumanStringWithLayout(host_shape);
|
||||
Shape device_shape = HostShapeToDeviceShape(host_shape);
|
||||
|
||||
TF_RETURN_IF_ERROR(alias_config.ForEachAliasWithStatus(
|
||||
[&](const ShapeIndex& output_index,
|
||||
absl::optional<HloInputOutputAliasConfig::Alias> alias) {
|
||||
if (alias && alias->must_alias()) {
|
||||
VLOG(1) << alias->ToString();
|
||||
const MaybeOwningDeviceMemory& original_input =
|
||||
(*arguments)[alias->parameter_number].Buffers().element(
|
||||
alias->parameter_index);
|
||||
if (!original_input.HasOwnership()) {
|
||||
return InvalidArgument(
|
||||
"An input was configured to be must-alias at "
|
||||
"compile time but not donated at runtime: %s",
|
||||
alias->ToString());
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
VLOG(3) << "AllocateOutputMemoryWithInputReuse, device = " << device_ordinal
|
||||
<< " host_shape = " << ShapeUtil::HumanStringWithLayout(host_shape);
|
||||
|
Loading…
x
Reference in New Issue
Block a user