diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 07ff323a3d7..1ca81c65038 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -386,7 +386,10 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
 
 tensorflow::Status BufferAssigner::AssignBuffersForComputation(
     const HloComputation* computation, bool is_thread_local,
-    const std::unordered_set<const HloInstruction*>* hlos_to_allocate,
+    const tensorflow::gtl::FlatSet<const HloInstruction*>* hlos_to_allocate,
+    const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
+    const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
+        colocated_allocations,
     BufferAssignment* assignment) {
   // Buffers are sorted and assigned to BufferAllocations in decreasing order of
   // size.
@@ -407,7 +410,7 @@ tensorflow::Status BufferAssigner::AssignBuffersForComputation(
 
   // Generate a post order sort of instructions for sorting of the
   // LogicalBuffers.
-  std::unordered_map<const HloInstruction*, int> post_order_position;
+  tensorflow::gtl::FlatMap<const HloInstruction*, int> post_order_position;
   int position = 0;
   for (auto* instruction : computation->MakeInstructionPostOrder()) {
     post_order_position.emplace(instruction, position);
@@ -445,7 +448,7 @@ tensorflow::Status BufferAssigner::AssignBuffersForComputation(
   std::vector<BufferAllocation::Index> allocation_indices;
   for (const auto* buffer : sorted_buffers) {
     VLOG(3) << "Assigning allocation to: " << buffer->ToString();
-    if (colocated_buffers_.find(buffer) != colocated_buffers_.end()) {
+    if (colocated_buffers.count(buffer) > 0) {
       // Colocated buffers are currently assigned in an earlier pass.
       continue;
     }
@@ -487,39 +490,6 @@ tensorflow::Status BufferAssigner::AssignBuffersForComputation(
       continue;
     }
 
-    if (buffer->instruction()->opcode() == HloOpcode::kCall &&
-        buffer->IsTopLevel()) {
-      // Assign the kCall instruction the same allocation as the root of the
-      // called computation. The points-to set of the root of the called
-      // computation must be unambigous so we know statically the allocation for
-      // the root.
-      //
-      // TODO(b/32491382): This is a hack. To properly handle this case
-      // points-to analysis, liveness analysis, and buffer assignment need to
-      // module-scope rather than computation-scope.
-      HloInstruction* call = buffer->instruction();
-      HloInstruction* computation_root = call->to_apply()->root_instruction();
-
-      // The buffer of the root of the called computation must be unambiguous.
-      const auto& root_points_to = assignment->GetPointsToSet(computation_root);
-      if (root_points_to.IsAmbiguous()) {
-        return Unimplemented(
-            "kCall of a computation with an ambiguous root points-to set");
-      }
-      CHECK_EQ(1, root_points_to.element(/*index=*/{}).size());
-      const LogicalBuffer* root_buffer =
-          root_points_to.element(/*index=*/{})[0];
-      BufferAllocation* root_allocation =
-          assignment->GetMutableAssignedAllocation(*root_buffer);
-
-      // Can't use MaybeAssignBuffer here because buffer liveness conservatively
-      // assumes buffers in different computations always interfere.
-      CHECK_GE(root_allocation->size(), buffer_size_(*buffer));
-      assignment->AddAssignment(*buffer, root_allocation,
-                                /*colocated_buffer=*/true);
-      continue;
-    }
-
     if (ShapeUtil::IsTuple(buffer->shape())) {
       // TODO(b/34669761): Don't reuse tuple buffers because the GPU backend
       // assumes longer buffer liveness than indicated by the analysis.
@@ -539,8 +509,7 @@ tensorflow::Status BufferAssigner::AssignBuffersForComputation(
              assignment->GetAllocations(operand, /*index=*/{})) {
           BufferAllocation* allocation =
               assignment->GetMutableAllocation(operand_allocation.index());
-          if (colocated_buffer_allocations_.find(allocation->index()) ==
-              colocated_buffer_allocations_.end()) {
+          if (colocated_allocations.count(allocation->index()) == 0) {
             // TODO(b/32491382) Colocated buffers are currently assigned in an
             // earlier pass, and so can break the "increasing allocation size"
             // invariant in this function (causing this CHECK to fail). However,
@@ -571,8 +540,7 @@ tensorflow::Status BufferAssigner::AssignBuffersForComputation(
         // Instructions are iterated in increasing buffer size, so any
         // previously create allocation must be large enough to hold this
         // instruction's output (with the exception of colocated buffers).
-        if (colocated_buffer_allocations_.find(allocation->index()) ==
-            colocated_buffer_allocations_.end()) {
+        if (colocated_allocations.count(allocation->index()) == 0) {
           // TODO(b/32491382) Colocated buffers are currently assigned in an
           // earlier pass, and so can break the "increasing allocation size"
           // invariant in this function (causing this CHECK to fail). However,
@@ -598,75 +566,147 @@ tensorflow::Status BufferAssigner::AssignBuffersForComputation(
   return tensorflow::Status::OK();
 }
 
-void BufferAssigner::AddBufferToColocatedBufferSet(
-    const HloInstruction* instruction, const ShapeIndex& index,
-    const TuplePointsToAnalysis& points_to_analysis,
-    BufferAssigner::ColocatedBufferSet* colocated_buffer_set) {
-  const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
-  // CopyInsertion ensures root points-to set is unambiguous and distinct.
-  CHECK(!points_to.IsAmbiguous());
-  CHECK(points_to.IsDistinct());
-  colocated_buffer_set->push_back(points_to.element(index)[0]);
-}
+// Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
+// the invariant that all sets in 'colocated_buffer_sets' are disjoint.
+//
+// A practical example of when this is necessary is a chain of kCall ops:
+//   computation.entry
+//     %a = call() -> computation.1
+//   computation.1
+//     %b = call() -> computation.2
+//   computation.2
+//     %c = parameter()
+// This yields the logical sets {%a,%b} {%b,%c} {%c}, which need to be merged
+// into a single set {%a,%b,%c}
+void BufferAssigner::AddSetToColocatedBufferSets(
+    const std::vector<const LogicalBuffer*>& colocated_set,
+    std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
+  if (colocated_set.empty()) {
+    return;
+  }
 
-// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
-// in the same allocation (currently just supports kWhile).
-std::vector<BufferAssigner::ColocatedBufferSet>
-BufferAssigner::BuildColocatedBufferSets(
-    const HloModule* module, const TuplePointsToAnalysis& points_to_analysis) {
-  std::vector<ColocatedBufferSet> colocated_buffer_sets;
-  for (auto& computation : module->computations()) {
-    for (auto& instruction : computation->instructions()) {
-      if (instruction->opcode() != HloOpcode::kWhile) {
-        continue;
+  // Find existing sets that overlap with at least one buffer from the
+  // colocated_set.
+  std::vector<size_t> overlap_set_indices;
+  for (const LogicalBuffer* buffer : colocated_set) {
+    for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) {
+      if ((*colocated_buffer_sets)[index].count(buffer) > 0) {
+        overlap_set_indices.push_back(index);
+      }
+    }
+  }
+
+  // If there is no overlap with existing sets, create a new set.
+  if (overlap_set_indices.empty()) {
+    colocated_buffer_sets->emplace_back();
+    colocated_buffer_sets->back().insert(colocated_set.begin(),
+                                         colocated_set.end());
+    return;
+  }
+
+  // Merge all overlap sets and the colocated set into the first overlap set.
+  ColocatedBufferSet* first = &(*colocated_buffer_sets)[overlap_set_indices[0]];
+  for (size_t index = 1; index < overlap_set_indices.size(); ++index) {
+    const ColocatedBufferSet& overlap_set =
+        (*colocated_buffer_sets)[overlap_set_indices[index]];
+    first->insert(overlap_set.begin(), overlap_set.end());
+  }
+  first->insert(colocated_set.begin(), colocated_set.end());
+
+  // Remove overlap sets that we just merged. The offset accounts for the fact
+  // that as elements are erased, the indices need to be adjusted. Keep in mind
+  // that overlap_set_indices is in increasing order.
+  for (size_t index = 1; index < overlap_set_indices.size(); ++index) {
+    const size_t offset = overlap_set_indices[index] - index + 1;
+    colocated_buffer_sets->erase(colocated_buffer_sets->begin() + offset);
+  }
+}
+
+namespace {
+// Checks that points-to set of 'instruction' is unambiguous and distinct
+// (ensured by CopyInsertion), then adds the buffer from the points-to set at
+// 'index' to 'colocated_set'.
+void AddBufferToColocatedSet(const HloInstruction* instruction,
+                             const ShapeIndex& index,
+                             const TuplePointsToAnalysis& points_to_analysis,
+                             std::vector<const LogicalBuffer*>* colocated_set) {
+  // CopyInsertion ensures root points-to set is unambiguous and distinct.
+  const auto& points_to = points_to_analysis.GetPointsToSet(instruction);
+  CHECK(!points_to.IsAmbiguous());
+  CHECK(points_to.IsDistinct());
+  colocated_set->push_back(points_to.element(index)[0]);
+}
+}  // namespace
+
+// Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
+// in the same allocation (currently just supports kWhile and kCall).
+void BufferAssigner::BuildColocatedBufferSets(
+    const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+    std::vector<ColocatedBufferSet>* colocated_buffer_sets) {
+  for (auto& computation : module->computations()) {
+    for (auto& instruction : computation->instructions()) {
+      const HloOpcode opcode = instruction->opcode();
+      if (opcode == HloOpcode::kWhile) {
+        HloInstruction* while_hlo = instruction.get();
+        TF_CHECK_OK(ShapeUtil::ForEachSubshape(
+            while_hlo->shape(),
+            [this, while_hlo, &points_to_analysis, colocated_buffer_sets](
+                const Shape& /*subshape*/, const ShapeIndex& index) {
+              std::vector<const LogicalBuffer*> colocated_set;
+              // Add while.init.
+              AddBufferToColocatedSet(while_hlo->operand(0), index,
+                                      points_to_analysis, &colocated_set);
+              // Add while.result.
+              AddBufferToColocatedSet(while_hlo, index, points_to_analysis,
+                                      &colocated_set);
+              // Add while.cond.parameter.
+              AddBufferToColocatedSet(
+                  while_hlo->while_condition()->parameter_instruction(0), index,
+                  points_to_analysis, &colocated_set);
+              // Add while.body.parameter.
+              AddBufferToColocatedSet(
+                  while_hlo->while_body()->parameter_instruction(0), index,
+                  points_to_analysis, &colocated_set);
+              // Add while.body.root.
+              AddBufferToColocatedSet(
+                  while_hlo->while_body()->root_instruction(), index,
+                  points_to_analysis, &colocated_set);
+              AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
+              return tensorflow::Status::OK();
+            }));
+      } else if (opcode == HloOpcode::kCall) {
+        HloInstruction* call_hlo = instruction.get();
+        HloInstruction* root_hlo = call_hlo->to_apply()->root_instruction();
+        TF_CHECK_OK(ShapeUtil::ForEachSubshape(
+            call_hlo->shape(),
+            [this, call_hlo, root_hlo, &points_to_analysis,
+             colocated_buffer_sets](const Shape& /*subshape*/,
+                                    const ShapeIndex& index) {
+              std::vector<const LogicalBuffer*> colocated_set;
+              // Add call.result.
+              AddBufferToColocatedSet(call_hlo, index, points_to_analysis,
+                                      &colocated_set);
+              // Add call.subcomputation.root.
+              AddBufferToColocatedSet(root_hlo, index, points_to_analysis,
+                                      &colocated_set);
+              AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
+              return tensorflow::Status::OK();
+            }));
       }
-      HloInstruction* while_hlo = instruction.get();
-      TF_CHECK_OK(ShapeUtil::ForEachSubshape(
-          while_hlo->shape(),
-          [this, &points_to_analysis, &while_hlo, &colocated_buffer_sets](
-              const Shape& /*subshape*/, const ShapeIndex& index) {
-            ColocatedBufferSet colocated_buffer_set;
-            // Add while.init.
-            AddBufferToColocatedBufferSet(while_hlo->operand(0), index,
-                                          points_to_analysis,
-                                          &colocated_buffer_set);
-            // Add while.result.
-            AddBufferToColocatedBufferSet(while_hlo, index, points_to_analysis,
-                                          &colocated_buffer_set);
-            // Add while.cond.parameter.
-            AddBufferToColocatedBufferSet(
-                while_hlo->while_condition()->parameter_instruction(0), index,
-                points_to_analysis, &colocated_buffer_set);
-            // Add while.body.parameter.
-            AddBufferToColocatedBufferSet(
-                while_hlo->while_body()->parameter_instruction(0), index,
-                points_to_analysis, &colocated_buffer_set);
-            // Add while.body.root.
-            AddBufferToColocatedBufferSet(
-                while_hlo->while_body()->root_instruction(), index,
-                points_to_analysis, &colocated_buffer_set);
-
-            colocated_buffer_sets.push_back(std::move(colocated_buffer_set));
-            return tensorflow::Status::OK();
-          }));
     }
   }
-  return colocated_buffer_sets;
 }
 
 // Assigns all colocated buffer sets in 'colocated_buffer_sets' to the same
 // allocation in 'assignment'.
 void BufferAssigner::AssignColocatedBufferSets(
     const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
-    BufferAssignment* assignment) {
-  for (const auto& colocated_buffer_set : colocated_buffer_sets) {
+    BufferAssignment* assignment,
+    tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers,
+    tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations) {
+  for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) {
     BufferAllocation* allocation = nullptr;
-    for (const auto& buffer : colocated_buffer_set) {
-      if (colocated_buffers_.find(buffer) != colocated_buffers_.end()) {
-        // ColocatedBufferSet duplicates can occur if a buffer is forwarded
-        // from one instruction to another (i.e. while.body param to root).
-        continue;
-      }
+    for (const LogicalBuffer* buffer : colocated_buffer_set) {
       if (allocation == nullptr) {
         // TODO(b/32491382) Avoid current trivial solution of using new
         // allocations for each colocated buffer set. When liveness has
@@ -675,12 +715,12 @@ void BufferAssigner::AssignColocatedBufferSets(
         allocation = assignment->NewAllocation(*buffer, buffer_size_(*buffer),
                                                /*is_thread_local=*/false,
                                                /*is_reusable=*/true);
-        colocated_buffer_allocations_.insert(allocation->index());
+        colocated_allocations->insert(allocation->index());
       } else {
         assignment->AddAssignment(*buffer, allocation,
                                   /*colocated_buffer=*/true);
       }
-      colocated_buffers_.insert(buffer);
+      colocated_buffers->insert(buffer);
     }
   }
 }
@@ -709,9 +749,9 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
 
   // Set of HLO's to allocate if hlos_to_allocate is given. Passed as a set to
   // AssignBuffersForComputation for fast membership testing.
-  std::unique_ptr<std::unordered_set<const HloInstruction*>> hlo_set;
+  std::unique_ptr<tensorflow::gtl::FlatSet<const HloInstruction*>> hlo_set;
   if (hlos_to_allocate != nullptr) {
-    hlo_set = MakeUnique<std::unordered_set<const HloInstruction*>>(
+    hlo_set = MakeUnique<tensorflow::gtl::FlatSet<const HloInstruction*>>(
         hlos_to_allocate->begin(), hlos_to_allocate->end());
   }
 
@@ -723,22 +763,26 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   // Once b/32491382 enables module-level liveness analysis, we may be able
   // to assign colocated buffers (or at least reuse their allocation for
   // buffers outside of the set) in AssignBuffersForComputation.
+  tensorflow::gtl::FlatSet<const LogicalBuffer*> colocated_buffers;
+  tensorflow::gtl::FlatSet<BufferAllocation::Index> colocated_allocations;
   if (colocate_related_buffers_) {
-    std::vector<ColocatedBufferSet> colocated_buffer_sets =
-        BuildColocatedBufferSets(module, assignment->points_to_analysis());
-    AssignColocatedBufferSets(colocated_buffer_sets, assignment.get());
+    std::vector<ColocatedBufferSet> colocated_buffer_sets;
+    BuildColocatedBufferSets(module, assignment->points_to_analysis(),
+                             &colocated_buffer_sets);
+    AssignColocatedBufferSets(colocated_buffer_sets, assignment.get(),
+                              &colocated_buffers, &colocated_allocations);
   }
 
   for (auto* computation : global_computations) {
     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
-        computation,
-        /*is_thread_local=*/false, hlo_set.get(), assignment.get()));
+        computation, /*is_thread_local=*/false, hlo_set.get(),
+        colocated_buffers, colocated_allocations, assignment.get()));
   }
   for (auto* computation : thread_local_computations) {
     TF_RET_CHECK(computation != module->entry_computation());
     TF_RETURN_IF_ERROR(AssignBuffersForComputation(
-        computation,
-        /*is_thread_local=*/true, hlo_set.get(), assignment.get()));
+        computation, /*is_thread_local=*/true, hlo_set.get(), colocated_buffers,
+        colocated_allocations, assignment.get()));
   }
 
   // Mark all buffers which may be live out of the entry computation as
@@ -747,17 +791,20 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
   auto root_instruction = entry->root_instruction();
   const PointsToSet& root_points_to =
       assignment->GetPointsToSet(root_instruction);
-  TF_RETURN_IF_ERROR(root_points_to.ForEachElement([&assignment](
-      const ShapeIndex& /*index*/, bool /*is_leaf*/,
-      const std::vector<const LogicalBuffer*>& buffers) {
-    for (auto buffer : buffers) {
-      if (assignment->HasAllocation(*buffer)) {
-        assignment->GetMutableAssignedAllocation(*buffer)->set_maybe_live_out(
-            true);
-      }
-    }
-    return tensorflow::Status::OK();
-  }));
+  TF_RETURN_IF_ERROR(root_points_to.ForEachElement(
+      [&assignment](const ShapeIndex& /*index*/, bool /*is_leaf*/,
+                    const std::vector<const LogicalBuffer*>& buffers) {
+        for (const LogicalBuffer* buffer : buffers) {
+          VLOG(3) << "maybe_live_out LogicalBuffer: " << buffer->ToString();
+          if (assignment->HasAllocation(*buffer)) {
+            BufferAllocation* alloc =
+                assignment->GetMutableAssignedAllocation(*buffer);
+            alloc->set_maybe_live_out(true);
+            VLOG(3) << "maybe_live_out BufferAllocation: " << alloc->ToString();
+          }
+        }
+        return tensorflow::Status::OK();
+      }));
 
   XLA_VLOG_LINES(2, assignment->ToString());
 
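As an aside for readers of this patch (not part of the diff): the following minimal, standalone sketch illustrates the disjoint-set invariant that AddSetToColocatedBufferSets maintains, using plain int ids and std::set in place of LogicalBuffer* and ColocatedBufferSet. It reproduces the kCall-chain example from the comment above, where {%a,%b}, {%b,%c}, {%c} collapse into a single set {%a,%b,%c}. For simplicity it deduplicates and sorts the overlap indices and erases merged sets back to front, rather than using the index-offset bookkeeping in the patch.

// Illustration only: mirrors the merge logic of AddSetToColocatedBufferSets,
// with int ids standing in for LogicalBuffer pointers.
#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

using BufferSet = std::set<int>;

void AddSet(const std::vector<int>& colocated, std::vector<BufferSet>* sets) {
  if (colocated.empty()) return;

  // Indices of existing sets that share a buffer with the new set. An ordered
  // std::set keeps them unique and increasing.
  std::set<std::size_t> overlaps;
  for (int buffer : colocated) {
    for (std::size_t i = 0; i < sets->size(); ++i) {
      if ((*sets)[i].count(buffer) > 0) overlaps.insert(i);
    }
  }

  // No overlap: start a new disjoint set.
  if (overlaps.empty()) {
    sets->emplace_back(colocated.begin(), colocated.end());
    return;
  }

  // Merge the new buffers and every overlapping set into the first overlapping
  // set, then erase the others back to front so earlier indices stay valid.
  std::vector<std::size_t> idx(overlaps.begin(), overlaps.end());
  BufferSet& first = (*sets)[idx[0]];
  first.insert(colocated.begin(), colocated.end());
  for (std::size_t i = idx.size(); i > 1; --i) {
    const std::size_t victim = idx[i - 1];
    first.insert((*sets)[victim].begin(), (*sets)[victim].end());
    sets->erase(sets->begin() + victim);
  }
}

int main() {
  // The kCall chain from the comment: {%a,%b}, {%b,%c}, {%c} -> {%a,%b,%c}.
  std::vector<BufferSet> sets;
  AddSet({1, 2}, &sets);  // %a, %b
  AddSet({2, 3}, &sets);  // %b, %c
  AddSet({3}, &sets);     // %c
  std::cout << sets.size() << " set(s); first set has " << sets[0].size()
            << " buffers\n";  // prints: 1 set(s); first set has 3 buffers
  return 0;
}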
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index b484ea51b19..e7aeb35967e 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/types.h"
@@ -317,7 +318,10 @@ class BufferAssigner {
   // included.
   tensorflow::Status AssignBuffersForComputation(
       const HloComputation* computation, bool is_thread_local,
-      const std::unordered_set<const HloInstruction*>* hlos_to_allocate,
+      const tensorflow::gtl::FlatSet<const HloInstruction*>* hlos_to_allocate,
+      const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
+      const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
+          colocated_allocations,
       BufferAssignment* assignment);
 
   // Tries to assign the given instruction to the given buffer. Returns if the
@@ -330,27 +334,28 @@ class BufferAssigner {
   // alias. Explicitly handling these colocated buffers is necessary because
   // points-to analysis is computation level scope and does not recognize
   // aliasing across computations (b/32491382).
-  using ColocatedBufferSet = std::vector<const LogicalBuffer*>;
+  using ColocatedBufferSet = tensorflow::gtl::FlatSet<const LogicalBuffer*>;
 
   // Returns a vector of ColocatedBufferSet objects, where each
   // ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
   // which should be colocated in the same buffer allocation.
-  std::vector<ColocatedBufferSet> BuildColocatedBufferSets(
-      const HloModule* module, const TuplePointsToAnalysis& points_to_analysis);
+  void BuildColocatedBufferSets(
+      const HloModule* module, const TuplePointsToAnalysis& points_to_analysis,
+      std::vector<ColocatedBufferSet>* colocated_buffer_sets);
 
   // For each buffer set in 'colocated_buffer_sets', assigns all buffers in the
   // same set to the same buffer allocation in 'assignment'.
   void AssignColocatedBufferSets(
       const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
-      BufferAssignment* assignment);
+      BufferAssignment* assignment,
+      tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers,
+      tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations);
 
-  // Checks that points-to set of 'instruction' is unambiguous and distinct
-  // (ensured by CopyInsertion), then adds buffer from point-to set at 'index'
-  // to 'colocated_buffer_set'.
-  void AddBufferToColocatedBufferSet(
-      const HloInstruction* instruction, const ShapeIndex& index,
-      const TuplePointsToAnalysis& points_to_analysis,
-      BufferAssigner::ColocatedBufferSet* colocated_buffer_set);
+  // Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
+  // the invariant that all sets in 'colocated_buffer_sets' are disjoint.
+  void AddSetToColocatedBufferSets(
+      const std::vector<const LogicalBuffer*>& colocated_set,
+      std::vector<ColocatedBufferSet>* colocated_buffer_sets);
 
   const HloModule* module_;
 
@@ -360,12 +365,6 @@ class BufferAssigner {
   // Indicates whether related buffers should share the same buffer allocation.
   const bool colocate_related_buffers_;
 
-  // Set of colocated buffers populated in AssignColocatedBufferSets.
-  std::unordered_set<const LogicalBuffer*> colocated_buffers_;
-
-  // Set of allocations containing colocated buffers.
-  std::unordered_set<BufferAllocation::Index> colocated_buffer_allocations_;
-
   TF_DISALLOW_COPY_AND_ASSIGN(BufferAssigner);
 };
 
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index bfa5bee2935..b8841c35f68 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -989,6 +989,41 @@ TEST_F(BufferAssignmentTest, TupleCustomCallAsOutput) {
       GetAllocation(*assignment, custom_call, /*index=*/{1}).maybe_live_out());
 }
 
+TEST_F(BufferAssignmentTest, TupleCallAsOutput) {
+  // Test a computation which returns a tuple call value.
+  auto module = MakeUnique<HloModule>(TestName());
+  auto elem_shape = f32vec4_;
+  auto tuple_shape = ShapeUtil::MakeTupleShape({elem_shape});
+
+  auto sub_builder = HloComputation::Builder(TestName() + "_sub");
+  auto sub_param = sub_builder.AddInstruction(
+      HloInstruction::CreateParameter(0, elem_shape, "sub_param"));
+  auto sub_tuple =
+      sub_builder.AddInstruction(HloInstruction::CreateTuple({sub_param}));
+  auto sub_computation = module->AddEmbeddedComputation(sub_builder.Build());
+
+  auto builder = HloComputation::Builder(TestName());
+  auto param = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, elem_shape, "param"));
+  auto call = builder.AddInstruction(
+      HloInstruction::CreateCall(tuple_shape, {param}, sub_computation));
+  module->AddEntryComputation(builder.Build());
+
+  auto assignment = RunBufferAssignment(module.get());
+
+  EXPECT_EQ(3, assignment->Allocations().size());
+  // Buffers for call are co-located with the sub-computation.
+  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}),
+            GetAllocation(*assignment, sub_tuple, /*index=*/{}));
+  EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}),
+            GetAllocation(*assignment, sub_param, /*index=*/{}));
+  // The parameter isn't aliased with anything.
+  EXPECT_NE(GetTopLevelAllocation(*assignment, param),
+            GetTopLevelAllocation(*assignment, sub_tuple));
+  EXPECT_NE(GetTopLevelAllocation(*assignment, param),
+            GetTopLevelAllocation(*assignment, sub_param));
+}
+
 TEST_F(BufferAssignmentTest, BitcastAsOutput) {
   // Test a computation which returns a bitcast value.
   auto builder = HloComputation::Builder(TestName());
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 4a3934cff78..9cafe3ad690 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -234,16 +234,22 @@ Status CpuCompiler::RunHloPasses(HloModule* hlo_module,
       /*is_layout_sensitive=*/true,
       [](const Shape&, const Shape&) { return true; });
   pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/true);
+  // Outline ops in the entry computation into calls to subcomputations.
+  legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags();
+  if (flags->xla_cpu_parallel) {
+    pipeline.AddPass<ParallelizationPreparation>();
+  }
   // Copy insertion should be performed immediately before IR emission to
   // avoid inserting unnecessary copies (later pass adds an instruction which
   // materializes the value) or missing a necessary copy (later pass removes
   // an instruction which materializes a value).
   pipeline.AddPass<CopyInsertion>();
-  pipeline.AddPass<HloDCE>();
-  legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags();
   if (flags->xla_cpu_parallel) {
+    // Re-run the outlining, in case any copies were inserted into the entry
+    // computation.
     pipeline.AddPass<ParallelizationPreparation>();
   }
+  pipeline.AddPass<HloDCE>();
   return pipeline.Run(hlo_module).status();
 }
 
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
index 2a0afafbf61..f6b1dcae75a 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
@@ -91,9 +91,16 @@ StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) {
     outlined.insert(instructions_to_outline.begin(),
                     instructions_to_outline.end());
 
+    // Optimization to avoid replacing a single existing kCall with another
+    // kCall that just calls the first one.
+    if (instructions_to_outline.size() == 1 &&
+        instructions_to_outline[0]->opcode() == HloOpcode::kCall) {
+      continue;
+    }
+
     module->OutlineExpressionFromComputation(
         instructions_to_outline,
-        tensorflow::strings::StrCat("computation_for_", instruction->name()),
+        tensorflow::strings::StrCat("pp_", instruction->name()),
         entry_computation);
     changed = true;
   }
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index f2fc38aa1c8..91468fd35b0 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -28,6 +28,15 @@ limitations under the License.
 
 namespace xla {
 
+namespace {
+void DumpModule(const Compiler::HloDumper& dumper_, const HloModule& module,
+                const string& message) {
+  dumper_(module, message);
+  VLOG(2) << "HLO " << message << ":";
+  XLA_VLOG_LINES(2, module.ToString());
+}
+}  // namespace
+
 StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
   legacy_flags::HloPassPipelineFlags* flags =
       legacy_flags::GetHloPassPipelineFlags();
@@ -47,10 +56,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
     // Emit label containing: "after foo-pass, before bar-pass".
     message.clear();
     tensorflow::strings::StrAppend(&message, prefix, ", before ", pass->name());
-    dumper_(*module, message);
-
-    VLOG(2) << "HLO " << message << ":";
-    XLA_VLOG_LINES(2, module->ToString());
+    DumpModule(dumper_, *module, message);
 
     TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module));
 
@@ -58,7 +64,7 @@ StatusOr<bool> HloPassPipeline::Run(HloModule* module) {
     prefix.clear();
     tensorflow::strings::StrAppend(&prefix, name(), ": after ", pass->name());
   }
-  dumper_(*module, prefix + ", pipeline end");
+  DumpModule(dumper_, *module, prefix + ", pipeline end");
   return changed;
 }
 
diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc
index 1c96b730345..0b5e6d51277 100644
--- a/tensorflow/compiler/xla/tests/call_test.cc
+++ b/tensorflow/compiler/xla/tests/call_test.cc
@@ -60,6 +60,14 @@ class CallOpTest : public ClientLibraryTestBase {
     return build_status.ConsumeValueOrDie();
   }
 
+  Computation CreateR0F32TupleComputation() {
+    ComputationBuilder builder(client_, "Tuple");
+    builder.Tuple({builder.Parameter(0, r0f32_, "x")});
+    auto build_status = builder.Build();
+    EXPECT_IS_OK(build_status.status());
+    return build_status.ConsumeValueOrDie();
+  }
+
   Shape r0f32_ = ShapeUtil::MakeShape(F32, {});
   Shape r1s0f32_ = ShapeUtil::MakeShape(F32, {0});
   Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2});
@@ -94,6 +102,16 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) {
   ComputeAndCompareR1<float>(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f));
 }
 
+XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32Tuple)) {
+  ComputationBuilder builder(client_, TestName());
+  Computation callee = CreateR0F32TupleComputation();
+  auto elem = LiteralUtil::CreateR0<float>(42.0);
+  auto tuple = LiteralUtil::MakeTuple({elem.get()});
+  builder.Call(callee, {builder.ConstantLiteral(*elem)});
+
+  ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f));
+}
+
 }  // namespace
 }  // namespace xla
 
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
index 591fff338cd..5c7079267ba 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc
@@ -44,8 +44,9 @@ TEST_F(LocalClientAotTest, Constant) {
   OpaqueData opaque_data{100, 20, 3};
   void* parameters[] = {&opaque_data};
   float out = 0;
-  float tmp = 0;
-  void* temporary_buffers[] = {&out, &tmp, nullptr};
+  float tmp1 = 0;
+  float tmp2 = 0;
+  void* temporary_buffers[] = {&out, &tmp1, &tmp2, nullptr};
   SumAndDouble(&out, &run_options, parameters, temporary_buffers);
   EXPECT_EQ(out, 246.0f);
 
diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
index 50d9ee50835..eed51bd6ad4 100644
--- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
+++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
@@ -86,14 +86,14 @@ int main(int argc, char** argv) {
       client->CompileAheadOfTime({instance}, options).ConsumeValueOrDie();
   auto result = xla::unique_ptr_static_cast<xla::cpu::CpuAotCompilationResult>(
       std::move(results.front()));
-  // We should have two buffers, one for the result and one temporary buffer,
-  // and both should be float-sized.  It's lame to hard-code this, but we need
+  // It's lame to hard-code the buffer assignments, but we need
   // local_client_aot_test.cc to be able to easily invoke the function.
   CHECK_EQ(result->result_buffer_index(), 0);
-  CHECK_EQ(result->buffer_sizes().size(), 3);
+  CHECK_EQ(result->buffer_sizes().size(), 4);
   CHECK_EQ(result->buffer_sizes()[0], sizeof(float));  // result buffer
   CHECK_EQ(result->buffer_sizes()[1], sizeof(float));  // temp buffer
-  CHECK_EQ(result->buffer_sizes()[2], -1);
+  CHECK_EQ(result->buffer_sizes()[2], sizeof(float));  // temp buffer
+  CHECK_EQ(result->buffer_sizes()[3], -1);             // param buffer
   if (triple.isOSBinFormatELF()) {
     // Check the ELF magic.
     CHECK_EQ(result->object_file_data()[0], 0x7F);