From 077b553fda4ecda0d478cbed2a1fbdeae0c42e58 Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins@google.com>
Date: Fri, 1 May 2020 14:16:44 -0700
Subject: [PATCH] [XLA:Python] Specify a 2MiB stack size for host stream
 threads.

[StreamExecutor] Allow HostExecutor users to control the stack sizes of threads used for HostStream via DeviceOptions::non_portable_tags.

Also include non_portable_tags in the keys used when creating an Executor. There seems to be no good reason that it is omitted.

Will fix https://github.com/google/jax/issues/432 when included in a jaxlib release.

PiperOrigin-RevId: 309472318
Change-Id: Ia2535616047390d6bf6f2da82a666a321dcc9f5d
---
 tensorflow/compiler/xla/pjrt/BUILD            |  1 +
 tensorflow/compiler/xla/pjrt/cpu_device.cc    |  9 ++++++--
 .../xla/service/interpreter/executor.h        |  3 ++-
 tensorflow/stream_executor/device_options.h   |  3 ++-
 tensorflow/stream_executor/host/BUILD         |  1 +
 .../stream_executor/host/host_gpu_executor.cc | 22 +++++++++++++++++++
 .../stream_executor/host/host_gpu_executor.h  | 13 +++++------
 .../stream_executor/host/host_stream.cc       | 15 +++++++++++--
 tensorflow/stream_executor/host/host_stream.h |  4 +++-
 tensorflow/stream_executor/lib/status.h       |  3 +++
 10 files changed, 60 insertions(+), 14 deletions(-)

diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD
index ddc4c007c09..dbd33705d0e 100644
--- a/tensorflow/compiler/xla/pjrt/BUILD
+++ b/tensorflow/compiler/xla/pjrt/BUILD
@@ -167,6 +167,7 @@ cc_library(
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/service:platform_util",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc
index f2bc472ed09..92648c26e40 100644
--- a/tensorflow/compiler/xla/pjrt/cpu_device.cc
+++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/pjrt/cpu_device.h"
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 
@@ -40,8 +41,12 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
 
   std::vector<std::unique_ptr<Device>> devices;
   for (int i = 0; i < client->device_count(); ++i) {
-    se::StreamExecutor* executor =
-        client->backend().stream_executor(i).ValueOrDie();
+    se::StreamExecutorConfig config;
+    config.ordinal = i;
+    config.device_options.non_portable_tags["host_thread_stack_size_in_bytes"] =
+        absl::StrCat(2048 * 1024);
+    TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
+                        platform->GetExecutor(config));
     auto device_state = absl::make_unique<LocalDeviceState>(
         executor, client, LocalDeviceState::kSynchronous, asynchronous,
         /*allow_event_reuse=*/false);
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index 3c35fda55f1..9e4bdeb2b2d 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -203,7 +203,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
 
   std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
       override {
-    return std::unique_ptr<internal::StreamInterface>(new host::HostStream());
+    return std::unique_ptr<internal::StreamInterface>(
+        new host::HostStream(/*thread_stack_size=*/0));
   }
 
   std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
diff --git a/tensorflow/stream_executor/device_options.h b/tensorflow/stream_executor/device_options.h
index 00eb8c8dbb0..98660441b42 100644
--- a/tensorflow/stream_executor/device_options.h
+++ b/tensorflow/stream_executor/device_options.h
@@ -64,7 +64,8 @@ struct DeviceOptions {
   unsigned flags() const { return flags_; }
 
   bool operator==(const DeviceOptions& other) const {
-    return flags_ == other.flags_;
+    return flags_ == other.flags_ &&
+           non_portable_tags == other.non_portable_tags;
   }
 
   bool operator!=(const DeviceOptions& other) const {
diff --git a/tensorflow/stream_executor/host/BUILD b/tensorflow/stream_executor/host/BUILD
index 362e2199284..be5af1f6ee7 100644
--- a/tensorflow/stream_executor/host/BUILD
+++ b/tensorflow/stream_executor/host/BUILD
@@ -112,6 +112,7 @@ cc_library(
         "//tensorflow/stream_executor:stream_executor_pimpl",
         "//tensorflow/stream_executor:timer",
         "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
     ],
     alwayslink = True,
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc
index 5242420fcdb..d6fd0ce9821 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.cc
+++ b/tensorflow/stream_executor/host/host_gpu_executor.cc
@@ -19,6 +19,8 @@ limitations under the License.
 
 #include <string.h>
 
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
 #include "absl/synchronization/notification.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
@@ -42,6 +44,20 @@ HostExecutor::HostExecutor(const PluginConfig &plugin_config)
 
 HostExecutor::~HostExecutor() {}
 
+port::Status HostExecutor::Init(int device_ordinal,
+                                DeviceOptions device_options) {
+  auto it =
+      device_options.non_portable_tags.find("host_thread_stack_size_in_bytes");
+  if (it != device_options.non_portable_tags.end()) {
+    if (!absl::SimpleAtoi(it->second, &thread_stack_size_in_bytes_)) {
+      return port::InvalidArgumentError(absl::StrCat(
+          "Unable to parse host_thread_stack_size_in_bytes as an integer: ",
+          it->second));
+    }
+  }
+  return port::Status::OK();
+}
+
 DeviceMemoryBase HostExecutor::Allocate(uint64 size, int64 memory_space) {
   CHECK_EQ(memory_space, 0);
   // Use a minimum alignment of 64 bytes to be friendly to AVX512 code.
@@ -332,5 +348,11 @@ rng::RngSupport *HostExecutor::CreateRng() {
   return status.ValueOrDie()(this);
 }
 
+std::unique_ptr<internal::StreamInterface>
+HostExecutor::GetStreamImplementation() {
+  return std::unique_ptr<internal::StreamInterface>(
+      new HostStream(thread_stack_size_in_bytes_));
+}
+
 }  // namespace host
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h
index d40a7a88015..c971ec89bf0 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.h
+++ b/tensorflow/stream_executor/host/host_gpu_executor.h
@@ -46,9 +46,9 @@ class HostExecutor : public internal::StreamExecutorInterface {
   explicit HostExecutor(const PluginConfig &plugin_config);
   ~HostExecutor() override;
 
-  port::Status Init(int device_ordinal, DeviceOptions device_options) override {
-    return port::Status::OK();
-  }
+  // The stack size used for host streams can be set via
+  // device_options.non_portable_tags["host_thread_stack_size_in_bytes"].
+  port::Status Init(int device_ordinal, DeviceOptions device_options) override;
 
   port::Status GetKernel(const MultiKernelLoaderSpec &spec,
                          KernelBase *kernel) override {
@@ -184,10 +184,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
     return nullptr;
   }
 
-  std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
-      override {
-    return std::unique_ptr<internal::StreamInterface>(new HostStream());
-  }
+  std::unique_ptr<internal::StreamInterface> GetStreamImplementation() override;
 
   std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
     return std::unique_ptr<internal::TimerInterface>(new HostTimer());
@@ -197,6 +194,8 @@ class HostExecutor : public internal::StreamExecutorInterface {
 
  private:
   const PluginConfig plugin_config_;
+  // Size of thread stacks for streams in bytes. '0' means "the default size".
+  size_t thread_stack_size_in_bytes_ = 0;
 };
 
 }  // namespace host
diff --git a/tensorflow/stream_executor/host/host_stream.cc b/tensorflow/stream_executor/host/host_stream.cc
index 413edc6739a..320b79ff37a 100644
--- a/tensorflow/stream_executor/host/host_stream.cc
+++ b/tensorflow/stream_executor/host/host_stream.cc
@@ -24,9 +24,20 @@ limitations under the License.
 namespace stream_executor {
 namespace host {
 
-HostStream::HostStream()
+namespace {
+
+port::ThreadOptions GetThreadOptions(size_t stack_size_in_bytes) {
+  port::ThreadOptions options;
+  options.stack_size = stack_size_in_bytes;
+  return options;
+}
+
+}  // namespace
+
+HostStream::HostStream(size_t stack_size_in_bytes)
     : thread_(port::Env::Default()->StartThread(
-          port::ThreadOptions(), "host_executor", [this]() { WorkLoop(); })) {}
+          GetThreadOptions(stack_size_in_bytes), "host_executor",
+          [this]() { WorkLoop(); })) {}
 
 HostStream::~HostStream() {
   {
diff --git a/tensorflow/stream_executor/host/host_stream.h b/tensorflow/stream_executor/host/host_stream.h
index 0a353d4a19b..2ee3f1f449c 100644
--- a/tensorflow/stream_executor/host/host_stream.h
+++ b/tensorflow/stream_executor/host/host_stream.h
@@ -31,7 +31,9 @@ namespace host {
 
 class HostStream : public internal::StreamInterface {
  public:
-  HostStream();
+  // stack_size_in_bytes may be '0', meaning "use the default thread stack
+  // size".
+  explicit HostStream(size_t stack_size_in_bytes);
   ~HostStream() override;
 
   bool EnqueueTask(std::function<void()> task);
diff --git a/tensorflow/stream_executor/lib/status.h b/tensorflow/stream_executor/lib/status.h
index 87269b4591a..170a7955979 100644
--- a/tensorflow/stream_executor/lib/status.h
+++ b/tensorflow/stream_executor/lib/status.h
@@ -36,6 +36,9 @@ using Status = tensorflow::Status;
 inline Status UnimplementedError(absl::string_view message) {
   return Status(error::UNIMPLEMENTED, message);
 }
+inline Status InvalidArgumentError(absl::string_view message) {
+  return Status(error::INVALID_ARGUMENT, message);
+}
 inline Status InternalError(absl::string_view message) {
   return Status(error::INTERNAL, message);
 }