[TF2XLA] Set up aliasing for resource variables even when not returning a tuple

PiperOrigin-RevId: 317414582
Change-Id: I45cd1f314331cb86a0257e7b7cf9d0639be84e99
This commit is contained in:
George Karpenkov 2020-06-19 18:16:58 -07:00 committed by TensorFlower Gardener
parent f0d0485b0d
commit aa7ff6aa28

View File

@ -268,6 +268,7 @@ Status BuildComputation(
return a->arg_num() < b->arg_num();
});
std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
for (const XlaResource* resource : arg_resources) {
DCHECK_LT(resource->arg_num(), args.size());
const XlaCompiler::Argument& arg = args[resource->arg_num()];
@ -289,20 +290,19 @@ Status BuildComputation(
update.type = resource->type();
update.shape = resource->shape();
update.modified = modified;
if (is_entry_computation && always_return_tuple &&
if (is_entry_computation &&
arg.resource_kind != XlaResource::kTensorArray &&
alias_resource_update) {
// Assuming tuple arg and results are used.
int64 output_index = elems.size();
if (use_tuple_arg) {
builder->SetUpAlias(/*output_index=*/{output_index},
/*param_number=*/0,
/*param_index=*/{update.input_index});
} else {
builder->SetUpAlias(/*output_index=*/{output_index},
/*param_number=*/update.input_index,
/*param_index=*/{});
}
xla::ShapeIndex param_index =
use_tuple_arg ? xla::ShapeIndex({update.input_index})
: xla::ShapeIndex{};
int param_number = use_tuple_arg ? 0 : update.input_index;
int64 output_index_num = elems.size();
xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
<< param_number << ", " << param_index.ToString() << ")";
aliases.push_back({output_index, param_number, param_index});
}
for (const auto& grad : resource->tensor_array_gradients()) {
update.tensor_array_gradients_accessed.insert(grad.first);
@ -381,8 +381,25 @@ Status BuildComputation(
xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding);
tuple = xla::Tuple(builder, elems);
}
if (!always_return_tuple && elems.size() == 1) {
bool returns_tuple = always_return_tuple || elems.size() != 1;
VLOG(3) << "Computation returns a tuple=" << returns_tuple;
if (!returns_tuple) {
xla::GetTupleElement(tuple, 0);
for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
if (alias.output_index == xla::ShapeIndex({0})) {
VLOG(3) << "For aliased parameter " << alias.param_number << ": "
<< alias.param_index.ToString()
<< " normalizing output_index from {0} to {}, as a scalar is "
"returned from the cluster";
alias.output_index = xla::ShapeIndex({});
}
}
}
for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
builder->SetUpAlias(alias.output_index, alias.param_number,
alias.param_index);
}
xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();