[TF2XLA] Set up aliasing for resource variables even when not returning a tuple

PiperOrigin-RevId: 317414582
Change-Id: I45cd1f314331cb86a0257e7b7cf9d0639be84e99
Author: George Karpenkov
Date: 2020-06-19 18:16:58 -07:00
Committed by: TensorFlower Gardener
Parent: f0d0485b0d
Commit: aa7ff6aa28
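
For context, the mechanism this commit extends: XLA input/output aliasing lets a compiled computation donate a parameter's buffer to an output, which is how TF2XLA updates resource variables in place. Below is a minimal sketch of the client API involved, assuming the usual XLA client headers; it is illustrative only and not part of this commit.

    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    xla::XlaBuilder builder("aliasing_example");
    xla::XlaOp v = xla::Parameter(&builder, /*parameter_number=*/0,
                                  xla::ShapeUtil::MakeShape(xla::F32, {2}), "v");
    xla::XlaOp updated = xla::Add(v, v);
    xla::Tuple(&builder, {updated});
    // Tell XLA that output tuple element {0} may reuse parameter 0's buffer.
    builder.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                       /*param_index=*/{});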


@@ -268,6 +268,7 @@ Status BuildComputation(
               return a->arg_num() < b->arg_num();
             });
 
+  std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
   for (const XlaResource* resource : arg_resources) {
     DCHECK_LT(resource->arg_num(), args.size());
     const XlaCompiler::Argument& arg = args[resource->arg_num()];
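
This first hunk records aliases in a vector instead of calling builder->SetUpAlias() immediately; they are applied in one pass near the end of BuildComputation, once the final output shape is known (see the last hunk). A self-contained sketch of that two-phase pattern, using stand-in types rather than the real XLA ones (the real xla::XlaBuilder::InputOutputAlias carries the same three fields):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    using ShapeIndex = std::vector<int64_t>;  // stand-in for xla::ShapeIndex

    struct InputOutputAlias {  // stand-in for xla::XlaBuilder::InputOutputAlias
      ShapeIndex output_index;
      int64_t param_number;
      ShapeIndex param_index;
    };

    int main() {
      std::vector<InputOutputAlias> aliases;
      // Phase 1: record aliases while walking the resource updates.
      aliases.push_back({/*output_index=*/{2}, /*param_number=*/0,
                         /*param_index=*/{1}});
      // Phase 2: apply them once the root shape is final (the real code
      // calls builder->SetUpAlias(...) here).
      for (const InputOutputAlias& a : aliases) {
        std::cout << "alias output {" << a.output_index[0] << "} -> param "
                  << a.param_number << "\n";
      }
    }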
@@ -289,20 +290,19 @@ Status BuildComputation(
       update.type = resource->type();
       update.shape = resource->shape();
       update.modified = modified;
-      if (is_entry_computation && always_return_tuple &&
+      if (is_entry_computation &&
           arg.resource_kind != XlaResource::kTensorArray &&
           alias_resource_update) {
         // Assuming tuple arg and results are used.
-        int64 output_index = elems.size();
-        if (use_tuple_arg) {
-          builder->SetUpAlias(/*output_index=*/{output_index},
-                              /*param_number=*/0,
-                              /*param_index=*/{update.input_index});
-        } else {
-          builder->SetUpAlias(/*output_index=*/{output_index},
-                              /*param_number=*/update.input_index,
-                              /*param_index=*/{});
-        }
+        xla::ShapeIndex param_index =
+            use_tuple_arg ? xla::ShapeIndex({update.input_index})
+                          : xla::ShapeIndex{};
+        int param_number = use_tuple_arg ? 0 : update.input_index;
+        int64 output_index_num = elems.size();
+        xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
+        VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
+                << param_number << ", " << param_index.ToString() << ")";
+        aliases.push_back({output_index, param_number, param_index});
       }
       for (const auto& grad : resource->tensor_array_gradients()) {
         update.tensor_array_gradients_accessed.insert(grad.first);
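
The rewritten body above computes where each modified resource lives on the parameter side: with a tuple argument, everything is packed into parameter 0 and the resource is addressed by a tuple index; otherwise each argument is its own parameter and the shape index into it is empty. A hedged sketch of that arithmetic as a standalone helper (the name AliasTarget is hypothetical):

    #include <cstdint>
    #include <utility>
    #include <vector>

    // Returns {param_number, param_index} for the argument at input_index,
    // mirroring the two addressing schemes in the hunk above.
    std::pair<int64_t, std::vector<int64_t>> AliasTarget(bool use_tuple_arg,
                                                         int64_t input_index) {
      if (use_tuple_arg) {
        // All arguments are packed into parameter 0; select a tuple element.
        return {0, {input_index}};
      }
      // One parameter per argument; the index into it is empty.
      return {input_index, {}};
    }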
@@ -381,8 +381,25 @@ Status BuildComputation(
     xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding);
     tuple = xla::Tuple(builder, elems);
   }
-  if (!always_return_tuple && elems.size() == 1) {
+  bool returns_tuple = always_return_tuple || elems.size() != 1;
+  VLOG(3) << "Computation returns a tuple=" << returns_tuple;
+  if (!returns_tuple) {
     xla::GetTupleElement(tuple, 0);
+
+    for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
+      if (alias.output_index == xla::ShapeIndex({0})) {
+        VLOG(3) << "For aliased parameter " << alias.param_number << ": "
+                << alias.param_index.ToString()
+                << " normalizing output_index from {0} to {}, as a scalar is "
+                   "returned from the cluster";
+        alias.output_index = xla::ShapeIndex({});
+      }
+    }
+  }
+
+  for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
+    builder->SetUpAlias(alias.output_index, alias.param_number,
+                        alias.param_index);
   }
 
   xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
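
This last hunk is the point of the change: aliases are now applied whether or not the computation returns a tuple. When the cluster returns a single value directly, an alias recorded against tuple element {0} must be rewritten to the empty shape index, because the root itself is the output. A self-contained sketch of that normalization, again with a stand-in alias type:

    #include <cstdint>
    #include <vector>

    struct Alias {  // stand-in for xla::XlaBuilder::InputOutputAlias
      std::vector<int64_t> output_index;
      int64_t param_number;
      std::vector<int64_t> param_index;
    };

    // When the root is the lone element itself rather than a one-element
    // tuple, {0} no longer addresses anything: point the alias at the root
    // via the empty shape index, as the hunk above does.
    void NormalizeForNonTupleReturn(std::vector<Alias>& aliases) {
      for (Alias& alias : aliases) {
        if (alias.output_index == std::vector<int64_t>{0}) {
          alias.output_index = {};
        }
      }
    }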