From aa7ff6aa28977826e7acae379e82da22482b2bf2 Mon Sep 17 00:00:00 2001
From: George Karpenkov <cheshire@google.com>
Date: Fri, 19 Jun 2020 18:16:58 -0700
Subject: [PATCH] [TF2XLA] Set up aliasing for resource variables even when not
 returning a tuple

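Previously, input/output aliasing for modified resource variables was
only set up when the computation was forced to return a tuple
(always_return_tuple). Collect the aliases first and install them only
after the tuple/non-tuple decision has been made, so a cluster that
returns a single (non-tuple) result can also alias its resource update
back to the corresponding parameter. In the single-result case the
stored output index {0} is normalized to {} before SetUpAlias is
called.

As an illustrative sketch (the parameter number 2 is hypothetical), a
single-result cluster that updates the resource passed as parameter 2
without a tuple argument now ends up with roughly:

    builder->SetUpAlias(/*output_index=*/{}, /*param_number=*/2,
                        /*param_index=*/{});

whereas before this change no alias was installed for it at all.
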
PiperOrigin-RevId: 317414582
Change-Id: I45cd1f314331cb86a0257e7b7cf9d0639be84e99
---
 tensorflow/compiler/tf2xla/xla_compiler.cc | 41 +++++++++++++++-------
 1 file changed, 29 insertions(+), 12 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 1cf3e10b774..c1aef3ff690 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -268,6 +268,7 @@ Status BuildComputation(
               return a->arg_num() < b->arg_num();
             });
 
+  std::vector<xla::XlaBuilder::InputOutputAlias> aliases;
   for (const XlaResource* resource : arg_resources) {
     DCHECK_LT(resource->arg_num(), args.size());
     const XlaCompiler::Argument& arg = args[resource->arg_num()];
@@ -289,20 +290,19 @@ Status BuildComputation(
       update.type = resource->type();
       update.shape = resource->shape();
       update.modified = modified;
-      if (is_entry_computation && always_return_tuple &&
+      if (is_entry_computation &&
           arg.resource_kind != XlaResource::kTensorArray &&
           alias_resource_update) {
         // Assuming tuple arg and results are used.
-        int64 output_index = elems.size();
-        if (use_tuple_arg) {
-          builder->SetUpAlias(/*output_index=*/{output_index},
-                              /*param_number=*/0,
-                              /*param_index=*/{update.input_index});
-        } else {
-          builder->SetUpAlias(/*output_index=*/{output_index},
-                              /*param_number=*/update.input_index,
-                              /*param_index=*/{});
-        }
+        xla::ShapeIndex param_index =
+            use_tuple_arg ? xla::ShapeIndex({update.input_index})
+                          : xla::ShapeIndex{};
+        int param_number = use_tuple_arg ? 0 : update.input_index;
+        int64 output_index_num = elems.size();
+        xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num});
+        VLOG(3) << "Storing alias: " << output_index.ToString() << ": ("
+                << param_number << ", " << param_index.ToString() << ")";
+        aliases.push_back({output_index, param_number, param_index});
       }
       for (const auto& grad : resource->tensor_array_gradients()) {
         update.tensor_array_gradients_accessed.insert(grad.first);
@@ -381,8 +381,25 @@ Status BuildComputation(
     xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding);
     tuple = xla::Tuple(builder, elems);
   }
-  if (!always_return_tuple && elems.size() == 1) {
+  bool returns_tuple = always_return_tuple || elems.size() != 1;
+  VLOG(3) << "Computation returns a tuple=" << returns_tuple;
+  if (!returns_tuple) {
     xla::GetTupleElement(tuple, 0);
+
+    for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
+      if (alias.output_index == xla::ShapeIndex({0})) {
+        VLOG(3) << "For aliased parameter " << alias.param_number << ": "
+                << alias.param_index.ToString()
+                << " normalizing output_index from {0} to {}, as a scalar is "
+                   "returned from the cluster";
+        alias.output_index = xla::ShapeIndex({});
+      }
+    }
+  }
+
+  for (xla::XlaBuilder::InputOutputAlias& alias : aliases) {
+    builder->SetUpAlias(alias.output_index, alias.param_number,
+                        alias.param_index);
   }
 
   xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();