From fce9ecb9cb44e03989c98367aa5a3b73c644a606 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 28 Mar 2017 06:14:21 -0800
Subject: [PATCH] [XLA:CPU] Implements LocalClient support for the parallel CPU
 backend.

Implements the ShapedBuffer overload of ParallelCpuExecutable::ExecuteOnStream
by factoring buffer allocation (AllocateBuffers) and compute-function dispatch
(ExecuteComputeFunctions) out of the existing DeviceMemoryBase path, exposes
the service Backend through LocalClient, and plumbs the inter-op thread pool
through ExecutableRunOptions, which the parallel executable requires to be
set.

Change: 151446054

---
 .../jit/kernels/xla_local_launch_op.cc        |   2 +
 .../compiler/xla/client/local_client.cc       |   8 +
 tensorflow/compiler/xla/client/local_client.h |   4 +
 .../service/cpu/parallel_cpu_executable.cc    | 206 ++++++++++++++----
 .../xla/service/cpu/parallel_cpu_executable.h |  29 +++
 .../xla/tests/local_client_test_base.cc       |   9 +-
 6 files changed, 208 insertions(+), 50 deletions(-)

diff --git a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
index 70b9c6fb0fb..8b43c7c1564 100644
--- a/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_local_launch_op.cc
@@ -260,6 +260,8 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
     xla::ExecutableRunOptions run_options;
     run_options.set_stream(stream);
     run_options.set_allocator(&xla_allocator);
+    run_options.set_inter_op_thread_pool(
+        ctx->device()->tensorflow_cpu_worker_threads()->workers);
     run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
     Env* env = Env::Default();
     auto start_time = env->NowMicros();
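
Note: the parallel CPU executable CHECK_NOTNULLs the inter-op pool (see
parallel_cpu_executable.cc below), so the launch op has to populate it. For
reference, a minimal sketch of the complete run-options wiring, assuming an
OpKernelContext* ctx, a stream, and an allocator as in this function:

    xla::ExecutableRunOptions run_options;
    run_options.set_stream(stream);             // stream to execute on
    run_options.set_allocator(&xla_allocator);  // device memory allocator
    // The inter-op pool lets the parallel backend run independent HLO
    // instructions concurrently; it must be non-null.
    run_options.set_inter_op_thread_pool(
        ctx->device()->tensorflow_cpu_worker_threads()->workers);
    // The Eigen device parallelizes work within a single instruction.
    run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());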
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 1b764564912..bfd14bc1c01 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -309,6 +309,14 @@ int LocalClient::default_device_ordinal() const {
   return local_service_->backend().default_device_ordinal();
 }
 
+const Backend& LocalClient::backend() const {
+  return local_service_->backend();
+}
+
+Backend* LocalClient::mutable_backend() {
+  return local_service_->mutable_backend();
+}
+
 StatusOr<std::unique_ptr<LocalExecutable>> LocalClient::Compile(
     const Computation& computation,
     const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 7bffc75ab5f..2c467efcea1 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -224,6 +224,10 @@ class LocalClient : public Client {
   // capability).
   bool device_ordinal_supported(int device_ordinal) const;
 
+  // Returns the backend used to execute computations.
+  const Backend& backend() const;
+  Backend* mutable_backend();
+
  private:
   LocalService* local_service_;
 };
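
Note: a hedged usage sketch of the new accessors (obtaining the client via
ClientLibrary::LocalClientOrDie() is an assumption; any LocalClient works).
It mirrors what DefaultExecutableRunOptions does in the test-base change at
the end of this patch:

    xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie();
    xla::ExecutableRunOptions options;
    options.set_allocator(allocator);  // hypothetical DeviceMemoryAllocator*
    options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool());
    options.set_intra_op_thread_pool(
        client->backend().eigen_intra_op_thread_pool_device());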
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
index bab3440e2c2..d727877ae3d 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
@@ -97,77 +97,81 @@ static void MarkLiveAddressesInOutput(
   }
 }
 
-StatusOr<perftools::gputools::DeviceMemoryBase>
-ParallelCpuExecutable::ExecuteOnStream(
-    const ServiceExecutableRunOptions* run_options,
-    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
-    HloExecutionProfile* hlo_execution_profile) {
-  se::Stream* stream = run_options->stream();
-  DeviceMemoryAllocator* memory_allocator = run_options->allocator();
-  VLOG(3) << "ExecuteOnStream arg size: " << arguments.size();
-  if (!arguments.empty()) {
-    VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque();
-  }
-
-  // Allocate the temporary buffers required for the computation.
-  se::StreamExecutor* stream_executor = stream->parent();
-  int device_ordinal = stream_executor->device_ordinal();
-  int64 buffer_count = assignment_->Allocations().size();
-  VLOG(3) << "temp buffer count: " << buffer_count;
-
-  std::vector<se::DeviceMemoryBase> device_allocations;
-  for (BufferAllocation::Index i = 0; i < buffer_count; ++i) {
+Status ParallelCpuExecutable::AllocateBuffers(
+    DeviceMemoryAllocator* memory_allocator, int device_ordinal,
+    std::vector<perftools::gputools::DeviceMemoryBase>* buffers) {
+  CHECK_EQ(buffers->size(), assignment_->Allocations().size());
+  VLOG(3) << "Allocating " << assignment_->Allocations().size()
+          << " allocations for module " << module().name();
+  for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
+       ++i) {
     auto& allocation = assignment_->GetAllocation(i);
+
+    VLOG(3) << allocation.ToString();
+
     if (allocation.is_entry_computation_parameter()) {
-      // Buffers do not need to be allocated for parameters.
-      device_allocations.push_back(se::DeviceMemoryBase(nullptr));
+      VLOG(3) << "allocation #" << i << " is a parameter";
       continue;
     }
 
     if (allocation.is_thread_local()) {
-      // Buffers do not need to be allocated for thread-local temporaries.
-      device_allocations.push_back(se::DeviceMemoryBase(nullptr));
+      VLOG(3) << "buffer #" << i << " is thread-local";
       continue;
     }
 
-    TF_ASSIGN_OR_RETURN(
-        se::DeviceMemoryBase device_allocation,
-        memory_allocator->Allocate(device_ordinal, allocation.size()));
+    int64 buffer_size = allocation.size();
+    if (!(*buffers)[i].is_null()) {
+      VLOG(3) << "buffer #" << i
+              << " is in the preallocated result ShapedBuffer";
+    } else {
+      TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate(
+                                             device_ordinal, buffer_size));
 
-    if (VLOG_IS_ON(3)) {
-      VLOG(3) << "ParallelCpuExecutable allocating " << allocation.size()
-              << " bytes for allocation #" << i << " ["
-              << device_allocation.opaque() << "]";
-      std::vector<string> parts;
-      for (const auto& buffer_offset_size : allocation.assigned_buffers()) {
-        const LogicalBuffer& buffer = *buffer_offset_size.first;
-        parts.push_back(tensorflow::strings::StrCat(
-            buffer.instruction()->parent()->name(), "::", buffer.ToString()));
-      }
-      VLOG(3) << "  " << tensorflow::str_util::Join(parts, ", ");
+      VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes ["
+              << (*buffers)[i].opaque() << "]";
     }
 
-    device_allocations.push_back(device_allocation);
     // Since the output buffer and all the temporary buffers were written into
     // by the JITed code, msan has no way of knowing their memory was
     // initialized. Mark them initialized so that msan doesn't flag loads from
     // these buffers.
-    TF_ANNOTATE_MEMORY_IS_INITIALIZED(device_allocation.opaque(),
-                                      allocation.size());
+    TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size);
   }
 
   TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
                       assignment_->GetUniqueTopLevelOutputSlice());
-  const BufferAllocation::Index result_index = result_slice.index();
-  VLOG(3) << "result index: " << result_index;
+  VLOG(3) << "result index: " << result_slice.index();
 
+  return Status::OK();
+}
+
+Status ParallelCpuExecutable::ExecuteComputeFunctions(
+    const ExecutableRunOptions* run_options,
+    tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
+    HloExecutionProfile* hlo_execution_profile) {
+  std::vector<se::DeviceMemoryBase> argument_buffers(arguments.size());
+  for (int i = 0; i < arguments.size(); ++i) {
+    TF_RET_CHECK(!ShapeUtil::IsTuple(arguments[i]->shape()));
+    argument_buffers[i] = arguments[i]->buffer(/*index=*/{});
+  }
+  return ExecuteComputeFunctions(run_options, argument_buffers, buffers,
+                                 hlo_execution_profile);
+}
+
+Status ParallelCpuExecutable::ExecuteComputeFunctions(
+    const ExecutableRunOptions* run_options,
+    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
+    HloExecutionProfile* hlo_execution_profile) {
   // Allocate profiling counters for each hlo instruction that we would like to
   // profile.  Allocate an additional profile counter for the entire
   // computation.
   std::vector<uint64> profile_counters(hlo_to_profile_idx_.size() + 1);
 
   std::vector<void*> buffer_pointers;
-  for (auto& device_allocation : device_allocations) {
+  buffer_pointers.reserve(buffers.size());
+  for (auto device_allocation : buffers) {
     buffer_pointers.push_back(device_allocation.opaque());
   }
 
@@ -210,8 +214,7 @@ ParallelCpuExecutable::ExecuteOnStream(
 
   void** temps_array = buffer_pointers.data();
   uint64* profile_counters_array = profile_counters.data();
-  auto* thread_pool =
-      CHECK_NOTNULL(run_options->run_options().inter_op_thread_pool());
+  auto* thread_pool = CHECK_NOTNULL(run_options->inter_op_thread_pool());
   tensorflow::mutex completion_queue_lock;
   tensorflow::condition_variable completion_queue_cv;
   std::deque<HloInstruction*> completion_queue;
@@ -310,6 +313,42 @@ ParallelCpuExecutable::ExecuteOnStream(
     }
   }
 
+  return Status::OK();
+}
+
+StatusOr<perftools::gputools::DeviceMemoryBase>
+ParallelCpuExecutable::ExecuteOnStream(
+    const ServiceExecutableRunOptions* run_options,
+    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+    HloExecutionProfile* hlo_execution_profile) {
+  se::Stream* stream = run_options->stream();
+  DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+  VLOG(3) << "ExecuteOnStream arg size: " << arguments.size();
+  if (!arguments.empty()) {
+    VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque();
+  }
+
+  // Allocate the temporary buffers required for the computation.
+  se::StreamExecutor* stream_executor = stream->parent();
+  int device_ordinal = stream_executor->device_ordinal();
+  int64 buffer_count = assignment_->Allocations().size();
+  VLOG(3) << "temp buffer count: " << buffer_count;
+
+  std::vector<se::DeviceMemoryBase> device_allocations(buffer_count);
+
+  TF_RETURN_IF_ERROR(
+      AllocateBuffers(memory_allocator, device_ordinal,
+                      &device_allocations));
+
+  TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
+                      assignment_->GetUniqueTopLevelOutputSlice());
+  const BufferAllocation::Index result_index = result_slice.index();
+  VLOG(3) << "result index: " << result_index;
+
+  TF_RETURN_IF_ERROR(ExecuteComputeFunctions(&run_options->run_options(),
+                                             arguments, device_allocations,
+                                             hlo_execution_profile));
+
   // Mark the buffers that are actually live (used in the output) when the
   // computation finishes executing.
   std::unordered_set<const void*> marked_addresses;
@@ -345,8 +384,74 @@ StatusOr<std::unique_ptr<ShapedBuffer>> ParallelCpuExecutable::ExecuteOnStream(
     const ServiceExecutableRunOptions* run_options,
     tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
     HloExecutionProfile* hlo_execution_profile) {
-  return Unimplemented(
-      "ParallelCpuExecutable not supported yet with LocalService execution");
+  if (GetRootPointsToSet().IsAmbiguous()) {
+    return Unimplemented("Points-to set of root instruction is ambiguous");
+  }
+
+  se::Stream* stream = run_options->stream();
+  DeviceMemoryAllocator* memory_allocator = run_options->allocator();
+  std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
+
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<ShapedBuffer> result_buffer,
+                      ShapedBuffer::MakeShapedBuffer(
+                          result_shape(), stream->parent()->platform(),
+                          stream->parent()->device_ordinal()));
+
+  TF_RETURN_IF_ERROR(AllocateBuffers(
+      memory_allocator, stream->parent()->device_ordinal(), &buffers));
+
+  TF_RETURN_IF_ERROR(ExecuteComputeFunctions(
+      &run_options->run_options(), arguments, buffers, hlo_execution_profile));
+
+  // Copy the DeviceMemoryBase values that contain the result array(s) into
+  // the corresponding locations in the ShapedBuffer returned to the caller.
+  std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
+  TF_RETURN_IF_ERROR(
+      result_buffer->mutable_shape_index_to_buffer_entry()
+          ->ForEachMutableElement(
+              [&buffers, &buffers_in_result, &result_buffer, this](
+                  const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) {
+                if (is_leaf) {
+                  const std::vector<const LogicalBuffer*>& sources =
+                      this->GetRootPointsToSet().element(index);
+                  // The points-to set is unambiguous, so the set should be
+                  // a singleton.
+                  CHECK_EQ(1, sources.size());
+                  const LogicalBuffer* buffer_source = sources[0];
+                  HloInstruction* src = buffer_source->instruction();
+
+                  // The source for this result buffer can be a nested buffer
+                  // such as a tuple element.
+
+                  // The source instruction should have a non-parameter buffer
+                  // assigned.
+                  TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
+                                      this->assignment_->GetUniqueSlice(
+                                          src, buffer_source->index()));
+                  CHECK(!slice.allocation()->is_entry_computation_parameter());
+
+                  const BufferAllocation::Index buffer_index = slice.index();
+                  const se::DeviceMemoryBase& buffer = buffers[buffer_index];
+                  CHECK(!buffer.is_null() || buffer.size() == 0);
+                  *buffer_entry = result_buffer->mutable_buffers()->size();
+                  result_buffer->mutable_buffers()->push_back(buffer);
+                  buffers_in_result[buffer_index] = true;
+                }
+                return Status::OK();
+              }));
+
+  // Free all buffers not in the result.
+  for (size_t i = 0; i < buffers.size(); ++i) {
+    se::DeviceMemoryBase alloc = buffers[i];
+    if (!buffers_in_result[i] && !alloc.is_null()) {
+      VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
+              << alloc.opaque() << "]";
+      TF_RETURN_IF_ERROR(memory_allocator->Deallocate(
+          stream->parent()->device_ordinal(), &alloc));
+    }
+  }
+
+  return std::move(result_buffer);
 }
 
 StatusOr<perftools::gputools::DeviceMemoryBase>
@@ -358,5 +463,10 @@ ParallelCpuExecutable::ExecuteAsyncOnStream(
       "Asynchronous execution on stream is not yet supported on CPU.");
 }
 
+const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const {
+  return assignment_->points_to_analysis().GetPointsToSet(
+      module().entry_computation()->root_instruction());
+}
+
 }  // namespace cpu
 }  // namespace xla
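
Note: the scheduling loop that consumes this completion queue predates the
patch and is elided from the hunks above. As a rough, self-contained
illustration of the pattern those declarations support (plain std:: types
stand in for tensorflow::mutex and friends, Node stands in for
HloInstruction, and std::thread stands in for thread_pool->Schedule(); all
names are illustrative, not the patch's code):

    #include <condition_variable>
    #include <deque>
    #include <mutex>
    #include <thread>
    #include <vector>

    struct Node {
      std::vector<Node*> users;  // instructions that consume this result
      int pending_operands = 0;  // operands not yet computed
      void Run() { /* invoke the JITed compute function for this node */ }
    };

    // Dispatches ready nodes to worker threads, then drains the completion
    // queue on the calling thread, scheduling users as they become ready.
    void RunGraph(const std::vector<Node*>& initially_ready, int total_nodes) {
      std::mutex completion_queue_lock;
      std::condition_variable completion_queue_cv;
      std::deque<Node*> completion_queue;
      std::vector<std::thread> workers;  // stand-in for the inter-op pool

      auto dispatch = [&](Node* n) {
        workers.emplace_back([&, n] {
          n->Run();
          std::lock_guard<std::mutex> lock(completion_queue_lock);
          completion_queue.push_back(n);
          completion_queue_cv.notify_one();
        });
      };

      for (Node* n : initially_ready) dispatch(n);

      for (int completed = 0; completed < total_nodes; ++completed) {
        Node* done;
        {
          std::unique_lock<std::mutex> lock(completion_queue_lock);
          completion_queue_cv.wait(lock,
                                   [&] { return !completion_queue.empty(); });
          done = completion_queue.front();
          completion_queue.pop_front();
        }
        // Only this thread mutates pending counts, so no atomics are needed.
        for (Node* user : done->users) {
          if (--user->pending_operands == 0) dispatch(user);
        }
      }
      for (std::thread& t : workers) t.join();
    }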
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
index 7ce059bb1da..7223de9f079 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
@@ -84,6 +84,35 @@ class ParallelCpuExecutable : public Executable {
   }
 
  private:
+  // Allocates the buffers required for execution and assigns them to the
+  // elements of "buffers". "buffers" must be sized to the number of
+  // allocations in the buffer assignment; element i corresponds to
+  // BufferAllocation::Index i. If an element already holds a non-null
+  // DeviceMemoryBase, no new buffer is allocated for that element.
+  Status AllocateBuffers(
+      DeviceMemoryAllocator* memory_allocator, int device_ordinal,
+      std::vector<perftools::gputools::DeviceMemoryBase>* buffers);
+
+  // Calls the generated functions in 'function_names_', performing the
+  // computation with the given arguments using the supplied buffers.
+  Status ExecuteComputeFunctions(
+      const ExecutableRunOptions* run_options,
+      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          arguments,
+      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          buffers,
+      HloExecutionProfile* hlo_execution_profile);
+  Status ExecuteComputeFunctions(
+      const ExecutableRunOptions* run_options,
+      tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
+      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          buffers,
+      HloExecutionProfile* hlo_execution_profile);
+
+  // Returns the points-to set of the root instruction of the entry
+  // computation, using points-to analysis from the buffer assignment.
+  const PointsToSet& GetRootPointsToSet() const;
+
   // The JIT containing compiled modules.
   tensorflow::mutex jit_mutex_;
   std::unique_ptr<SimpleOrcJIT> jit_ GUARDED_BY(jit_mutex_);
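
Note: to make the AllocateBuffers contract concrete, a small hedged sketch
(result_index and preallocated_result are hypothetical). The caller sizes
the vector to the allocation count and may pre-seed entries that already
have backing memory; AllocateBuffers fills only the null entries:

    std::vector<perftools::gputools::DeviceMemoryBase> buffers(
        assignment_->Allocations().size());
    // Hypothetical pre-seed: this non-null entry is kept as-is and skipped.
    buffers[result_index] = preallocated_result;
    TF_RETURN_IF_ERROR(
        AllocateBuffers(memory_allocator, device_ordinal, &buffers));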
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index 8948d77aec8..7fe4c9020f4 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -187,8 +187,13 @@ ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions()
 }
 
 ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
-  return ExecutableRunOptions().set_allocator(
-      GetOrCreateAllocator(local_client_->platform()));
+  ExecutableRunOptions run_options;
+  run_options.set_inter_op_thread_pool(
+      local_client_->backend().inter_op_thread_pool());
+  run_options.set_intra_op_thread_pool(
+      local_client_->backend().eigen_intra_op_thread_pool_device());
+  run_options.set_allocator(GetOrCreateAllocator(local_client_->platform()));
+  return run_options;
 }
 
 std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocallyOrDie(
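
Note: with these defaults, LocalClient tests exercise the parallel CPU
backend without extra setup. A hedged sketch of a test body (builder calls
follow the client API of this era; exact names are assumptions):

    xla::ComputationBuilder builder(local_client_, "add_scalars");
    builder.Add(builder.ConstantR0<float>(41.0f),
                builder.ConstantR0<float>(1.0f));
    // ExecuteLocallyOrDie picks up DefaultExecutableRunOptions(), which now
    // carries the allocator and both thread pools.
    std::unique_ptr<xla::ScopedShapedBuffer> result =
        ExecuteLocallyOrDie(builder.Build().ValueOrDie(), /*arguments=*/{});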