[XLA:Python] Specify a 2MiB stack size for host stream threads.
[StreamExecutor] Allow HostExecutor users to control the stack sizes of threads used for HostStream via the "host_thread_stack_size_in_bytes" entry of DeviceOptions::non_portable_tags. Also include non_portable_tags in the keys used when creating an Executor. There seems to be no good reason that they are omitted. Will fix https://github.com/google/jax/issues/432 when included in a jaxlib release. PiperOrigin-RevId: 309472318 Change-Id: Ia2535616047390d6bf6f2da82a666a321dcc9f5d
This commit is contained in:
parent
88a2edd2c7
commit
077b553fda
@ -167,6 +167,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
|
||||
@ -40,8 +41,12 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
|
||||
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
for (int i = 0; i < client->device_count(); ++i) {
|
||||
se::StreamExecutor* executor =
|
||||
client->backend().stream_executor(i).ValueOrDie();
|
||||
se::StreamExecutorConfig config;
|
||||
config.ordinal = i;
|
||||
config.device_options.non_portable_tags["host_thread_stack_size_in_bytes"] =
|
||||
absl::StrCat(2048 * 1024);
|
||||
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
|
||||
platform->GetExecutor(config));
|
||||
auto device_state = absl::make_unique<LocalDeviceState>(
|
||||
executor, client, LocalDeviceState::kSynchronous, asynchronous,
|
||||
/*allow_event_reuse=*/false);
|
||||
|
@ -203,7 +203,8 @@ class XlaInterpreterExecutor : public internal::StreamExecutorInterface {
|
||||
|
||||
std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
|
||||
override {
|
||||
return std::unique_ptr<internal::StreamInterface>(new host::HostStream());
|
||||
return std::unique_ptr<internal::StreamInterface>(
|
||||
new host::HostStream(/*thread_stack_size=*/0));
|
||||
}
|
||||
|
||||
std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
|
||||
|
@ -64,7 +64,8 @@ struct DeviceOptions {
|
||||
unsigned flags() const { return flags_; }
|
||||
|
||||
bool operator==(const DeviceOptions& other) const {
|
||||
return flags_ == other.flags_;
|
||||
return flags_ == other.flags_ &&
|
||||
non_portable_tags == other.non_portable_tags;
|
||||
}
|
||||
|
||||
bool operator!=(const DeviceOptions& other) const {
|
||||
|
@ -112,6 +112,7 @@ cc_library(
|
||||
"//tensorflow/stream_executor:stream_executor_pimpl",
|
||||
"//tensorflow/stream_executor:timer",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
alwayslink = True,
|
||||
|
@ -19,6 +19,8 @@ limitations under the License.
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
|
||||
@ -42,6 +44,20 @@ HostExecutor::HostExecutor(const PluginConfig &plugin_config)
|
||||
|
||||
HostExecutor::~HostExecutor() {}
|
||||
|
||||
// Initializes the executor.
//
// If device_options.non_portable_tags contains the key
// "host_thread_stack_size_in_bytes", its value is parsed as an integer and
// stored in thread_stack_size_in_bytes_. Returns an InvalidArgument error if
// the value cannot be parsed; in that case thread_stack_size_in_bytes_ keeps
// its previous value. Returns OK otherwise.
port::Status HostExecutor::Init(int device_ordinal,
                                DeviceOptions device_options) {
  auto it =
      device_options.non_portable_tags.find("host_thread_stack_size_in_bytes");
  if (it == device_options.non_portable_tags.end()) {
    // No override requested; keep the current setting.
    return port::Status::OK();
  }

  // Parse into a local rather than the member: absl::SimpleAtoi leaves its
  // output in an unspecified state when parsing fails, and a failed Init must
  // not corrupt the configured stack size.
  size_t parsed_stack_size = 0;
  if (!absl::SimpleAtoi(it->second, &parsed_stack_size)) {
    return port::InvalidArgumentError(absl::StrCat(
        "Unable to parse host_thread_stack_size_in_bytes as an integer: ",
        it->second));
  }
  thread_stack_size_in_bytes_ = parsed_stack_size;
  return port::Status::OK();
}
|
||||
|
||||
DeviceMemoryBase HostExecutor::Allocate(uint64 size, int64 memory_space) {
|
||||
CHECK_EQ(memory_space, 0);
|
||||
// Use a minimum alignment of 64 bytes to be friendly to AVX512 code.
|
||||
@ -332,5 +348,11 @@ rng::RngSupport *HostExecutor::CreateRng() {
|
||||
return status.ValueOrDie()(this);
|
||||
}
|
||||
|
||||
std::unique_ptr<internal::StreamInterface>
|
||||
HostExecutor::GetStreamImplementation() {
|
||||
return std::unique_ptr<internal::StreamInterface>(
|
||||
new HostStream(thread_stack_size_in_bytes_));
|
||||
}
|
||||
|
||||
} // namespace host
|
||||
} // namespace stream_executor
|
||||
|
@ -46,9 +46,9 @@ class HostExecutor : public internal::StreamExecutorInterface {
|
||||
explicit HostExecutor(const PluginConfig &plugin_config);
|
||||
~HostExecutor() override;
|
||||
|
||||
port::Status Init(int device_ordinal, DeviceOptions device_options) override {
|
||||
return port::Status::OK();
|
||||
}
|
||||
// The stack size used for host streams can be set via
|
||||
// device_options.non_portable_tags["host_thread_stack_size_in_bytes"].
|
||||
port::Status Init(int device_ordinal, DeviceOptions device_options) override;
|
||||
|
||||
port::Status GetKernel(const MultiKernelLoaderSpec &spec,
|
||||
KernelBase *kernel) override {
|
||||
@ -184,10 +184,7 @@ class HostExecutor : public internal::StreamExecutorInterface {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
|
||||
override {
|
||||
return std::unique_ptr<internal::StreamInterface>(new HostStream());
|
||||
}
|
||||
std::unique_ptr<internal::StreamInterface> GetStreamImplementation() override;
|
||||
|
||||
std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
|
||||
return std::unique_ptr<internal::TimerInterface>(new HostTimer());
|
||||
@ -197,6 +194,8 @@ class HostExecutor : public internal::StreamExecutorInterface {
|
||||
|
||||
private:
|
||||
const PluginConfig plugin_config_;
|
||||
// Size of thread stacks for streams in bytes. '0' means "the default size".
|
||||
size_t thread_stack_size_in_bytes_ = 0;
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
|
@ -24,9 +24,20 @@ limitations under the License.
|
||||
namespace stream_executor {
|
||||
namespace host {
|
||||
|
||||
HostStream::HostStream()
|
||||
namespace {
|
||||
|
||||
// Returns a port::ThreadOptions whose stack_size field is set to
// `stack_size_in_bytes`; no other field is modified after default
// construction.
port::ThreadOptions GetThreadOptions(size_t stack_size_in_bytes) {
  port::ThreadOptions thread_options;
  thread_options.stack_size = stack_size_in_bytes;
  return thread_options;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
HostStream::HostStream(size_t stack_size_in_bytes)
|
||||
: thread_(port::Env::Default()->StartThread(
|
||||
port::ThreadOptions(), "host_executor", [this]() { WorkLoop(); })) {}
|
||||
GetThreadOptions(stack_size_in_bytes), "host_executor",
|
||||
[this]() { WorkLoop(); })) {}
|
||||
|
||||
HostStream::~HostStream() {
|
||||
{
|
||||
|
@ -31,7 +31,9 @@ namespace host {
|
||||
|
||||
class HostStream : public internal::StreamInterface {
|
||||
public:
|
||||
HostStream();
|
||||
// stack_size_in_bytes may be '0', meaning "use the default thread stack
|
||||
// size".
|
||||
explicit HostStream(size_t stack_size_in_bytes);
|
||||
~HostStream() override;
|
||||
|
||||
bool EnqueueTask(std::function<void()> task);
|
||||
|
@ -36,6 +36,9 @@ using Status = tensorflow::Status;
|
||||
inline Status UnimplementedError(absl::string_view message) {
|
||||
return Status(error::UNIMPLEMENTED, message);
|
||||
}
|
||||
// Returns a Status with error code INVALID_ARGUMENT carrying `message`.
inline Status InvalidArgumentError(absl::string_view message) {
|
||||
return Status(error::INVALID_ARGUMENT, message);
|
||||
}
|
||||
inline Status InternalError(absl::string_view message) {
|
||||
return Status(error::INTERNAL, message);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user