From aa7ff6aa28977826e7acae379e82da22482b2bf2 Mon Sep 17 00:00:00 2001
From: George Karpenkov <cheshire@google.com>
Date: Fri, 19 Jun 2020 18:16:58 -0700
Subject: [PATCH] [TF2XLA] Set up aliasing for resource variables even when
 not returning a tuple

PiperOrigin-RevId: 317414582
Change-Id: I45cd1f314331cb86a0257e7b7cf9d0639be84e99
---
 tensorflow/compiler/tf2xla/xla_compiler.cc | 41 +++++++++++++++-------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 1cf3e10b774..c1aef3ff690 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -268,6 +268,7 @@ Status BuildComputation(
                 return a->arg_num() < b->arg_num();
               });
 
+  std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
   for (const XlaResource* resource : arg_resources) {
     DCHECK_LT(resource->arg_num(), args.size());
     const XlaCompiler::Argument& arg = args[resource->arg_num()];
@@ -289,20 +290,19 @@ Status BuildComputation(
       update.type = resource->type();
       update.shape = resource->shape();
       update.modified = modified;
-      if (is_entry_computation && always_return_tuple &&
+      if (is_entry_computation &&
           arg.resource_kind != XlaResource::kTensorArray &&
           alias_resource_update) {
         // Assuming tuple arg and results are used.
-        int64 output_index = elems.size();
-        if (use_tuple_arg) {
-          builder->SetUpAlias(/*output_index=*/{output_index},
-                              /*param_number=*/0,
-                              /*param_index=*/{update.input_index});
-        } else {
-          builder->SetUpAlias(/*output_index=*/{output_index},
-                              /*param_number=*/update.input_index,
-                              /*param_index=*/{});
-        }
+        xla::ShapeIndex param_index =
+            use_tuple_arg ? xla::ShapeIndex({update.input_index})
+                          : xla::ShapeIndex{};
+        int param_number = use_tuple_arg ? 0 : update.input_index;
+        int64 output_index_num = elems.size();
+        xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
+        VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
+                << param_number << ", " << param_index.ToString() << ")";
+        aliases.push_back({output_index, param_number, param_index});
       }
       for (const auto& grad : resource->tensor_array_gradients()) {
         update.tensor_array_gradients_accessed.insert(grad.first);
@@ -381,8 +381,25 @@ Status BuildComputation(
     xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding);
     tuple = xla::Tuple(builder, elems);
   }
-  if (!always_return_tuple && elems.size() == 1) {
+  bool returns_tuple = always_return_tuple || elems.size() != 1;
+  VLOG(3) << "Computation returns a tuple=" << returns_tuple;
+  if (!returns_tuple) {
     xla::GetTupleElement(tuple, 0);
+
+    for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
+      if (alias.output_index == xla::ShapeIndex({0})) {
+        VLOG(3) << "For aliased parameter " << alias.param_number << ": "
+                << alias.param_index.ToString()
+                << " normalizing output_index from {0} to {}, as a scalar is "
+                   "returned from the cluster";
+        alias.output_index = xla::ShapeIndex({});
+      }
+    }
+  }
+
+  for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
+    builder->SetUpAlias(alias.output_index, alias.param_number,
+                        alias.param_index);
   }
 
   xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();