[TF2XLA] Set up aliasing for resource variables even when not returning a tuple
PiperOrigin-RevId: 317414582 Change-Id: I45cd1f314331cb86a0257e7b7cf9d0639be84e99
This commit is contained in:
parent
f0d0485b0d
commit
aa7ff6aa28
@ -268,6 +268,7 @@ Status BuildComputation(
|
||||
return a->arg_num() < b->arg_num();
|
||||
});
|
||||
|
||||
std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
|
||||
for (const XlaResource* resource : arg_resources) {
|
||||
DCHECK_LT(resource->arg_num(), args.size());
|
||||
const XlaCompiler::Argument& arg = args[resource->arg_num()];
|
||||
@ -289,20 +290,19 @@ Status BuildComputation(
|
||||
update.type = resource->type();
|
||||
update.shape = resource->shape();
|
||||
update.modified = modified;
|
||||
if (is_entry_computation && always_return_tuple &&
|
||||
if (is_entry_computation &&
|
||||
arg.resource_kind != XlaResource::kTensorArray &&
|
||||
alias_resource_update) {
|
||||
// Assuming tuple arg and results are used.
|
||||
int64 output_index = elems.size();
|
||||
if (use_tuple_arg) {
|
||||
builder->SetUpAlias(/*output_index=*/{output_index},
|
||||
/*param_number=*/0,
|
||||
/*param_index=*/{update.input_index});
|
||||
} else {
|
||||
builder->SetUpAlias(/*output_index=*/{output_index},
|
||||
/*param_number=*/update.input_index,
|
||||
/*param_index=*/{});
|
||||
}
|
||||
xla::ShapeIndex param_index =
|
||||
use_tuple_arg ? xla::ShapeIndex({update.input_index})
|
||||
: xla::ShapeIndex{};
|
||||
int param_number = use_tuple_arg ? 0 : update.input_index;
|
||||
int64 output_index_num = elems.size();
|
||||
xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
|
||||
VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
|
||||
<< param_number << ", " << param_index.ToString() << ")";
|
||||
aliases.push_back({output_index, param_number, param_index});
|
||||
}
|
||||
for (const auto& grad : resource->tensor_array_gradients()) {
|
||||
update.tensor_array_gradients_accessed.insert(grad.first);
|
||||
@ -381,8 +381,25 @@ Status BuildComputation(
|
||||
xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding);
|
||||
tuple = xla::Tuple(builder, elems);
|
||||
}
|
||||
if (!always_return_tuple && elems.size() == 1) {
|
||||
bool returns_tuple = always_return_tuple || elems.size() != 1;
|
||||
VLOG(3) << "Computation returns a tuple=" << returns_tuple;
|
||||
if (!returns_tuple) {
|
||||
xla::GetTupleElement(tuple, 0);
|
||||
|
||||
for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
|
||||
if (alias.output_index == xla::ShapeIndex({0})) {
|
||||
VLOG(3) << "For aliased parameter " << alias.param_number << ": "
|
||||
<< alias.param_index.ToString()
|
||||
<< " normalizing output_index from {0} to {}, as a scalar is "
|
||||
"returned from the cluster";
|
||||
alias.output_index = xla::ShapeIndex({});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
|
||||
builder->SetUpAlias(alias.output_index, alias.param_number,
|
||||
alias.param_index);
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
|
||||
|
Loading…
Reference in New Issue
Block a user