From 077b553fda4ecda0d478cbed2a1fbdeae0c42e58 Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins@google.com>
Date: Fri, 1 May 2020 14:16:44 -0700
Subject: [PATCH] [XLA:Python] Specify a 2MiB stack size for host stream
 threads.

[StreamExecutor] Allow HostExecutor users to control the stack sizes of
threads used for HostStream via the DeviceOptions non_portable_tag
"host_thread_stack_size_in_bytes".

Also include non_portable_tags in the keys used when creating an Executor.
There seems to be no good reason for them to be omitted.

Will fix https://github.com/google/jax/issues/432 when included in a jaxlib
release.

PiperOrigin-RevId: 309472318
Change-Id: Ia2535616047390d6bf6f2da82a666a321dcc9f5d
---
 tensorflow/compiler/xla/pjrt/BUILD            |  1 +
 tensorflow/compiler/xla/pjrt/cpu_device.cc    |  9 ++++++--
 .../xla/service/interpreter/executor.h        |  3 ++-
 tensorflow/stream_executor/device_options.h   |  3 ++-
 tensorflow/stream_executor/host/BUILD         |  1 +
 .../stream_executor/host/host_gpu_executor.cc | 22 +++++++++++++++++++
 .../stream_executor/host/host_gpu_executor.h  | 13 +++++------
 .../stream_executor/host/host_stream.cc       | 15 +++++++++++--
 tensorflow/stream_executor/host/host_stream.h |  4 +++-
 tensorflow/stream_executor/lib/status.h       |  3 +++
 10 files changed, 60 insertions(+), 14 deletions(-)

diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD
index ddc4c007c09..dbd33705d0e 100644
--- a/tensorflow/compiler/xla/pjrt/BUILD
+++ b/tensorflow/compiler/xla/pjrt/BUILD
@@ -167,6 +167,7 @@ cc_library(
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/service:platform_util",
+        "@com_google_absl//absl/strings",
     ],
 )
 
diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc
index f2bc472ed09..92648c26e40 100644
--- a/tensorflow/compiler/xla/pjrt/cpu_device.cc
+++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc
@@ -15,6 +15,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/pjrt/cpu_device.h"
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/service/platform_util.h"
 
@@ -40,8 +41,12 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   std::vector<std::unique_ptr<Device>> devices;
 
   for (int i = 0; i < client->device_count(); ++i) {
-    se::StreamExecutor* executor =
-        client->backend().stream_executor(i).ValueOrDie();
+    se::StreamExecutorConfig config;
+    config.ordinal = i;
+    config.device_options.non_portable_tags["host_thread_stack_size_in_bytes"] =
+        absl::StrCat(2048 * 1024);
+    TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
+                        platform->GetExecutor(config));
     auto device_state = absl::make_unique<LocalDeviceState>(
         executor, client, LocalDeviceState::kSynchronous, asynchronous,
         /*allow_event_reuse=*/false);
diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h
index 3c35fda55f1..9e4bdeb2b2d 100644
--- a/tensorflow/compiler/xla/service/interpreter/executor.h
+++ b/tensorflow/compiler/xla/service/interpreter/executor.h
@@ -203,7 +203,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
 
   std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
       override {
-    return std::unique_ptr<internal::StreamInterface>(new host::HostStream());
+    return std::unique_ptr<internal::StreamInterface>(
+        new host::HostStream(/*stack_size_in_bytes=*/0));
   }
 
   std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
diff --git a/tensorflow/stream_executor/device_options.h b/tensorflow/stream_executor/device_options.h
index 00eb8c8dbb0..98660441b42 100644
--- a/tensorflow/stream_executor/device_options.h
+++ b/tensorflow/stream_executor/device_options.h
@@ -64,7 +64,8 @@ struct DeviceOptions {
   unsigned flags() const { return flags_; }
 
   bool operator==(const DeviceOptions& other) const {
-    return flags_ == other.flags_;
+    return flags_ == other.flags_ &&
+           non_portable_tags == other.non_portable_tags;
   }
 
   bool operator!=(const DeviceOptions& other) const {
diff --git a/tensorflow/stream_executor/host/BUILD b/tensorflow/stream_executor/host/BUILD
index 362e2199284..be5af1f6ee7 100644
--- a/tensorflow/stream_executor/host/BUILD
+++ b/tensorflow/stream_executor/host/BUILD
@@ -112,6 +112,7 @@ cc_library(
         "//tensorflow/stream_executor:stream_executor_pimpl",
         "//tensorflow/stream_executor:timer",
         "//tensorflow/stream_executor/lib",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
     ],
     alwayslink = True,
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.cc b/tensorflow/stream_executor/host/host_gpu_executor.cc
index 5242420fcdb..d6fd0ce9821 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.cc
+++ b/tensorflow/stream_executor/host/host_gpu_executor.cc
@@ -19,6 +19,8 @@ limitations under the License.
 
 #include <string.h>
 
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
 #include "absl/synchronization/notification.h"
 #include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/profile_utils/cpu_utils.h"
@@ -42,6 +44,20 @@ HostExecutor::HostExecutor(const PluginConfig &plugin_config)
 
 HostExecutor::~HostExecutor() {}
 
+port::Status HostExecutor::Init(int device_ordinal,
+                                DeviceOptions device_options) {
+  auto it =
+      device_options.non_portable_tags.find("host_thread_stack_size_in_bytes");
+  if (it != device_options.non_portable_tags.end()) {
+    if (!absl::SimpleAtoi(it->second, &thread_stack_size_in_bytes_)) {
+      return port::InvalidArgumentError(absl::StrCat(
+          "Unable to parse host_thread_stack_size_in_bytes as an integer: ",
+          it->second));
+    }
+  }
+  return port::Status::OK();
+}
+
 DeviceMemoryBase HostExecutor::Allocate(uint64 size, int64 memory_space) {
   CHECK_EQ(memory_space, 0);
   // Use a minimum alignment of 64 bytes to be friendly to AVX512 code.
@@ -332,5 +348,11 @@ rng::RngSupport *HostExecutor::CreateRng() {
   return status.ValueOrDie()(this);
 }
 
+std::unique_ptr<internal::StreamInterface>
+HostExecutor::GetStreamImplementation() {
+  return std::unique_ptr<internal::StreamInterface>(
+      new HostStream(thread_stack_size_in_bytes_));
+}
+
 }  // namespace host
 }  // namespace stream_executor
diff --git a/tensorflow/stream_executor/host/host_gpu_executor.h b/tensorflow/stream_executor/host/host_gpu_executor.h
index d40a7a88015..c971ec89bf0 100644
--- a/tensorflow/stream_executor/host/host_gpu_executor.h
+++ b/tensorflow/stream_executor/host/host_gpu_executor.h
@@ -46,9 +46,9 @@ class HostExecutor : public internal::StreamExecutorInterface {
   explicit HostExecutor(const PluginConfig &plugin_config);
   ~HostExecutor() override;
 
-  port::Status Init(int device_ordinal, DeviceOptions device_options) override {
-    return port::Status::OK();
-  }
+  // The stack size used for host streams can be set via
+  // device_options.non_portable_tags["host_thread_stack_size_in_bytes"].
+  port::Status Init(int device_ordinal, DeviceOptions device_options) override;
 
   port::Status GetKernel(const MultiKernelLoaderSpec &spec,
                          KernelBase *kernel) override {
@@ -184,10 +184,7 @@
     return nullptr;
   }
 
-  std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
-      override {
-    return std::unique_ptr<internal::StreamInterface>(new HostStream());
-  }
+  std::unique_ptr<internal::StreamInterface> GetStreamImplementation() override;
 
   std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
     return std::unique_ptr<internal::TimerInterface>(new HostTimer());
@@ -197,6 +194,8 @@
 
  private:
   const PluginConfig plugin_config_;
+  // Size of thread stacks for streams in bytes. '0' means "the default size".
+  size_t thread_stack_size_in_bytes_ = 0;
 };
 
 }  // namespace host
diff --git a/tensorflow/stream_executor/host/host_stream.cc b/tensorflow/stream_executor/host/host_stream.cc
index 413edc6739a..320b79ff37a 100644
--- a/tensorflow/stream_executor/host/host_stream.cc
+++ b/tensorflow/stream_executor/host/host_stream.cc
@@ -24,9 +24,20 @@ limitations under the License.
 
namespace stream_executor { namespace host { -HostStream::HostStream() +namespace { + +port::ThreadOptions GetThreadOptions(size_t stack_size_in_bytes) { + port::ThreadOptions options; + options.stack_size = stack_size_in_bytes; + return options; +} + +} // namespace + +HostStream::HostStream(size_t stack_size_in_bytes) : thread_(port::Env::Default()->StartThread( - port::ThreadOptions(), "host_executor", [this]() { WorkLoop(); })) {} + GetThreadOptions(stack_size_in_bytes), "host_executor", + [this]() { WorkLoop(); })) {} HostStream::~HostStream() { { diff --git a/tensorflow/stream_executor/host/host_stream.h b/tensorflow/stream_executor/host/host_stream.h index 0a353d4a19b..2ee3f1f449c 100644 --- a/tensorflow/stream_executor/host/host_stream.h +++ b/tensorflow/stream_executor/host/host_stream.h @@ -31,7 +31,9 @@ namespace host { class HostStream : public internal::StreamInterface { public: - HostStream(); + // stack_size_in_bytes may be '0', meaning "use the default thread stack + // size". + explicit HostStream(size_t stack_size_in_bytes); ~HostStream() override; bool EnqueueTask(std::function<void()> task); diff --git a/tensorflow/stream_executor/lib/status.h b/tensorflow/stream_executor/lib/status.h index 87269b4591a..170a7955979 100644 --- a/tensorflow/stream_executor/lib/status.h +++ b/tensorflow/stream_executor/lib/status.h @@ -36,6 +36,9 @@ using Status = tensorflow::Status; inline Status UnimplementedError(absl::string_view message) { return Status(error::UNIMPLEMENTED, message); } +inline Status InvalidArgumentError(absl::string_view message) { + return Status(error::INVALID_ARGUMENT, message); +} inline Status InternalError(absl::string_view message) { return Status(error::INTERNAL, message); }
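
--
Usage sketch (not part of the patch): the snippet below shows how a
StreamExecutor client can request a 2 MiB stack for host stream threads
through the new tag, mirroring the cpu_device.cc hunk above. The tag key,
StreamExecutorConfig fields, and TF_ASSIGN_OR_RETURN usage all come from this
patch; the ordinal value and the surrounding setup are illustrative, and a
se::Platform* platform is assumed to be available (for example via
xla::PlatformUtil::GetPlatform("Host")).

  se::StreamExecutorConfig config;
  config.ordinal = 0;  // One executor per device; cpu_device.cc loops over all.
  // HostExecutor::Init() parses this tag with absl::SimpleAtoi. Leaving the
  // tag unset, or setting it to "0", keeps the platform-default stack size.
  config.device_options.non_portable_tags["host_thread_stack_size_in_bytes"] =
      absl::StrCat(2 * 1024 * 1024);
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      platform->GetExecutor(config));

Note that a malformed tag value is surfaced by HostExecutor::Init() as an
InvalidArgument error rather than being silently ignored.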