[XLA:Python] Specify a 2MiB stack size for host stream threads.

[StreamExecutor] Allow HostExecutor users to control the stack sizes of threads used for HostStream via.

Also include non_portable_tags in the keys used when creating an Executor. There seems to be no good reason that it is omitted.

Will fix https://github.com/google/jax/issues/432 when included in a jaxlib release.

PiperOrigin-RevId: 309472318
Change-Id: Ia2535616047390d6bf6f2da82a666a321dcc9f5d
This commit is contained in:
Peter Hawkins 2020-05-01 14:16:44 -07:00 committed by TensorFlower Gardener
parent 88a2edd2c7
commit 077b553fda
10 changed files with 60 additions and 14 deletions

View File

@ -167,6 +167,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/service:platform_util",
"@com_google_absl//absl/strings",
],
)

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
@ -40,8 +41,12 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
std::vector<std::unique_ptr<Device>> devices;
for (int i = 0; i < client->device_count(); ++i) {
se::StreamExecutor* executor =
client->backend().stream_executor(i).ValueOrDie();
se::StreamExecutorConfig config;
config.ordinal = i;
config.device_options.non_portable_tags["host_thread_stack_size_in_bytes"] =
absl::StrCat(2048 * 1024);
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
platform->GetExecutor(config));
auto device_state = absl::make_unique<LocalDeviceState>(
executor, client, LocalDeviceState::kSynchronous, asynchronous,
/*allow_event_reuse=*/false);

View File

@ -203,7 +203,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
override {
return std::unique_ptr<internal::StreamInterface>(new host::HostStream());
return std::unique_ptr<internal::StreamInterface>(
new host::HostStream(/*thread_stack_size=*/0));
}
std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {

View File

@ -64,7 +64,8 @@ struct DeviceOptions {
unsigned flags() const { return flags_; }
bool operator==(const DeviceOptions& other) const {
return flags_ == other.flags_;
return flags_ == other.flags_ &&
non_portable_tags == other.non_portable_tags;
}
bool operator!=(const DeviceOptions& other) const {

View File

@ -112,6 +112,7 @@ cc_library(
"//tensorflow/stream_executor:stream_executor_pimpl",
"//tensorflow/stream_executor:timer",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
],
alwayslink = True,

View File

@ -19,6 +19,8 @@ limitations under the License.
#include <string.h>
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
@ -42,6 +44,20 @@ HostExecutor::HostExecutor(const PluginConfig &plugin_config)
HostExecutor::~HostExecutor() {}
port::Status HostExecutor::Init(int device_ordinal,
DeviceOptions device_options) {
auto it =
device_options.non_portable_tags.find("host_thread_stack_size_in_bytes");
if (it != device_options.non_portable_tags.end()) {
if (!absl::SimpleAtoi(it->second, &thread_stack_size_in_bytes_)) {
return port::InvalidArgumentError(absl::StrCat(
"Unable to parse host_thread_stack_size_in_bytes as an integer: ",
it->second));
}
}
return port::Status::OK();
}
DeviceMemoryBase HostExecutor::Allocate(uint64 size, int64 memory_space) {
CHECK_EQ(memory_space, 0);
// Use a minimum alignment of 64 bytes to be friendly to AVX512 code.
@ -332,5 +348,11 @@ rng::RngSupport *HostExecutor::CreateRng() {
return status.ValueOrDie()(this);
}
std::unique_ptr<internal::StreamInterface>
HostExecutor::GetStreamImplementation() {
return std::unique_ptr<internal::StreamInterface>(
new HostStream(thread_stack_size_in_bytes_));
}
} // namespace host
} // namespace stream_executor

View File

@ -46,9 +46,9 @@ class HostExecutor : public internal::StreamExecutorInterface {
explicit HostExecutor(const PluginConfig &plugin_config);
~HostExecutor() override;
port::Status Init(int device_ordinal, DeviceOptions device_options) override {
return port::Status::OK();
}
// The stack size used for host streams can be set via
// device_options.non_portable_tags["host_stack_size"].
port::Status Init(int device_ordinal, DeviceOptions device_options) override;
port::Status GetKernel(const MultiKernelLoaderSpec &spec,
KernelBase *kernel) override {
@ -184,10 +184,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
return nullptr;
}
std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
override {
return std::unique_ptr<internal::StreamInterface>(new HostStream());
}
std::unique_ptr<internal::StreamInterface> GetStreamImplementation() override;
std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
return std::unique_ptr<internal::TimerInterface>(new HostTimer());
@ -197,6 +194,8 @@ class HostExecutor : public internal::StreamExecutorInterface {
private:
const PluginConfig plugin_config_;
// Size of thread stacks for streams in bytes. '0' means "the default size".
size_t thread_stack_size_in_bytes_ = 0;
};
} // namespace host

View File

@ -24,9 +24,20 @@ limitations under the License.
namespace stream_executor {
namespace host {
HostStream::HostStream()
namespace {
port::ThreadOptions GetThreadOptions(size_t stack_size_in_bytes) {
port::ThreadOptions options;
options.stack_size = stack_size_in_bytes;
return options;
}
} // namespace
HostStream::HostStream(size_t stack_size_in_bytes)
: thread_(port::Env::Default()->StartThread(
port::ThreadOptions(), "host_executor", [this]() { WorkLoop(); })) {}
GetThreadOptions(stack_size_in_bytes), "host_executor",
[this]() { WorkLoop(); })) {}
HostStream::~HostStream() {
{

View File

@ -31,7 +31,9 @@ namespace host {
class HostStream : public internal::StreamInterface {
public:
HostStream();
// stack_size_in_bytes may be '0', meaning "use the default thread stack
// size".
explicit HostStream(size_t stack_size_in_bytes);
~HostStream() override;
bool EnqueueTask(std::function<void()> task);

View File

@ -36,6 +36,9 @@ using Status = tensorflow::Status;
inline Status UnimplementedError(absl::string_view message) {
return Status(error::UNIMPLEMENTED, message);
}
inline Status InvalidArgumentError(absl::string_view message) {
return Status(error::INVALID_ARGUMENT, message);
}
inline Status InternalError(absl::string_view message) {
return Status(error::INTERNAL, message);
}