[SE] Make ExecutorCache thread-safe, change ExecutorCache::Insert to ExecutorCache::GetOrCreate. Add support for creating Executors for different device ordinals in parallel.
[XLA] Create Executors in parallel. PiperOrigin-RevId: 163734988
This commit is contained in:
parent
7ebed6678c
commit
122750a879
tensorflow
compiler
stream_executor
@ -64,23 +64,8 @@ ExecutorPlatform::ExecutorForDeviceWithPluginConfig(
|
||||
|
||||
port::StatusOr<StreamExecutor*> ExecutorPlatform::GetExecutor(
|
||||
const StreamExecutorConfig& config) {
|
||||
mutex_lock lock(executors_mutex_);
|
||||
|
||||
port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
|
||||
if (status.ok()) {
|
||||
return status.ValueOrDie();
|
||||
}
|
||||
|
||||
port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
|
||||
GetUncachedExecutor(config);
|
||||
if (!executor.ok()) {
|
||||
return executor.status();
|
||||
}
|
||||
|
||||
StreamExecutor* naked_executor = executor.ValueOrDie().get();
|
||||
SE_RETURN_IF_ERROR(
|
||||
executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
|
||||
return naked_executor;
|
||||
return executor_cache_.GetOrCreate(
|
||||
config, [&]() { return GetUncachedExecutor(config); });
|
||||
}
|
||||
|
||||
port::StatusOr<std::unique_ptr<StreamExecutor>>
|
||||
|
@ -67,9 +67,6 @@ class ExecutorPlatform : public Platform {
|
||||
// This platform's name.
|
||||
string name_;
|
||||
|
||||
// mutex that guards the ordinal-to-executor map.
|
||||
mutable mutex executors_mutex_;
|
||||
|
||||
// Cache of created StreamExecutors.
|
||||
ExecutorCache executor_cache_;
|
||||
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
@ -140,21 +141,32 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
|
||||
device_count = 1;
|
||||
}
|
||||
std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
se::StreamExecutorConfig config;
|
||||
config.ordinal = i;
|
||||
auto executor_status = platform->GetExecutor(config);
|
||||
if (executor_status.ok()) {
|
||||
se::StreamExecutor* executor = executor_status.ValueOrDie();
|
||||
if (IsDeviceSupported(executor)) {
|
||||
stream_executors[i] = executor;
|
||||
}
|
||||
} else {
|
||||
LOG(WARNING) << "unable to create StreamExecutor for " << platform->Name()
|
||||
<< ":" << i << ": "
|
||||
<< executor_status.status().error_message();
|
||||
VLOG(1) << "Initializing devices";
|
||||
{
|
||||
tensorflow::thread::ThreadPool thread_pool(
|
||||
tensorflow::Env::Default(), "device_initialization", device_count);
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
thread_pool.Schedule([platform, i, &stream_executors]() {
|
||||
VLOG(1) << "Started device init " << i;
|
||||
se::StreamExecutorConfig config;
|
||||
config.ordinal = i;
|
||||
auto executor_status = platform->GetExecutor(config);
|
||||
if (executor_status.ok()) {
|
||||
se::StreamExecutor* executor = executor_status.ValueOrDie();
|
||||
if (IsDeviceSupported(executor)) {
|
||||
stream_executors[i] = executor;
|
||||
}
|
||||
} else {
|
||||
LOG(WARNING) << "unable to create StreamExecutor for "
|
||||
<< platform->Name() << ":" << i << ": "
|
||||
<< executor_status.status().error_message();
|
||||
}
|
||||
VLOG(1) << "Finished device init " << i;
|
||||
});
|
||||
}
|
||||
// Block here in thread_pool destructor until all devices are initialized.
|
||||
}
|
||||
VLOG(1) << "Device initialization complete";
|
||||
if (std::all_of(stream_executors.begin(), stream_executors.end(),
|
||||
[](se::StreamExecutor* s) { return s == nullptr; })) {
|
||||
return InternalError("no supported devices found for platform %s",
|
||||
|
@ -127,23 +127,8 @@ port::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDeviceWithPluginConfig(
|
||||
|
||||
port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
|
||||
const StreamExecutorConfig& config) {
|
||||
mutex_lock lock(mu_);
|
||||
|
||||
port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
|
||||
if (status.ok()) {
|
||||
return status.ValueOrDie();
|
||||
}
|
||||
|
||||
port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
|
||||
GetUncachedExecutor(config);
|
||||
if (!executor.ok()) {
|
||||
return executor.status();
|
||||
}
|
||||
|
||||
StreamExecutor* naked_executor = executor.ValueOrDie().get();
|
||||
SE_RETURN_IF_ERROR(
|
||||
executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
|
||||
return naked_executor;
|
||||
return executor_cache_.GetOrCreate(
|
||||
config, [&]() { return GetUncachedExecutor(config); });
|
||||
}
|
||||
|
||||
port::StatusOr<std::unique_ptr<StreamExecutor>>
|
||||
|
@ -88,9 +88,6 @@ class CudaPlatform : public Platform {
|
||||
// This platform's name.
|
||||
string name_;
|
||||
|
||||
// mutex that guards internal state.
|
||||
mutable mutex mu_;
|
||||
|
||||
// Cache of created executors.
|
||||
ExecutorCache executor_cache_;
|
||||
|
||||
|
@ -20,39 +20,76 @@ limitations under the License.
|
||||
namespace perftools {
|
||||
namespace gputools {
|
||||
|
||||
port::Status ExecutorCache::Insert(const StreamExecutorConfig& config,
|
||||
std::unique_ptr<StreamExecutor> entry) {
|
||||
if (Get(config).ok()) {
|
||||
return port::Status(port::error::ALREADY_EXISTS,
|
||||
"An executor with a matching config already exists.");
|
||||
port::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
|
||||
const StreamExecutorConfig& config,
|
||||
const std::function<ExecutorFactory>& factory) {
|
||||
Entry* entry = nullptr;
|
||||
{
|
||||
mutex_lock lock{mutex_};
|
||||
entry = &cache_[config.ordinal];
|
||||
// Release the map lock; the address of 'entry' is stable because
|
||||
// std::map guarantees reference stability.
|
||||
}
|
||||
|
||||
cache_[config.ordinal].emplace_back(Entry(config, std::move(entry)));
|
||||
|
||||
return port::Status::OK();
|
||||
}
|
||||
|
||||
port::StatusOr<StreamExecutor*> ExecutorCache::Get(
|
||||
const StreamExecutorConfig& config) {
|
||||
auto entries = cache_.find(config.ordinal);
|
||||
if (entries == cache_.end()) {
|
||||
return port::Status(
|
||||
port::error::NOT_FOUND,
|
||||
port::Printf("No executors registered for ordinal %d", config.ordinal));
|
||||
}
|
||||
|
||||
for (const auto& iter : entries->second) {
|
||||
// Acquire the per-Entry mutex without holding the map mutex. Initializing
|
||||
// an Executor may be expensive, so we want to allow concurrent
|
||||
// initialization of different entries.
|
||||
mutex_lock lock{entry->configurations_mutex};
|
||||
for (const auto& iter : entry->configurations) {
|
||||
if (iter.first.plugin_config == config.plugin_config &&
|
||||
iter.first.device_options == config.device_options) {
|
||||
VLOG(2) << "hit in cache";
|
||||
return iter.second.get();
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(2) << "building executor";
|
||||
port::StatusOr<std::unique_ptr<StreamExecutor>> result = factory();
|
||||
if (!result.ok()) {
|
||||
VLOG(2) << "failed to get build executor: " << result.status();
|
||||
// If construction failed, leave the cache Entry around, but with a null
|
||||
// executor.
|
||||
return result.status();
|
||||
}
|
||||
entry->configurations.emplace_back(config, std::move(result.ValueOrDie()));
|
||||
return entry->configurations.back().second.get();
|
||||
}
|
||||
|
||||
// Looks up an already-cached executor for 'config'; never builds one.
// Returns the cached StreamExecutor pointer when an entry with equal
// plugin_config and device_options exists for config.ordinal, otherwise a
// NOT_FOUND status. The pointer is obtained via get(), so ownership stays
// with the cache entry.
port::StatusOr<StreamExecutor*> ExecutorCache::Get(
|
||||
const StreamExecutorConfig& config) {
|
||||
Entry* entry = nullptr;
|
||||
{
|
||||
mutex_lock lock{mutex_};
|
||||
// NOTE(review): operator[] on the map inserts an empty Entry when this
// ordinal has not been seen before, so a pure lookup can still add a map
// node; the empty() check below then reports NOT_FOUND for it.
entry = &cache_[config.ordinal];
|
||||
// Release the map lock; the address of 'entry' is stable because
|
||||
// std::map guarantees reference stability.
|
||||
}
|
||||
// Per-entry lock, held for the rest of the function while the entry's
// configurations are inspected.
mutex_lock lock{entry->configurations_mutex};
|
||||
if (entry->configurations.empty()) {
|
||||
return port::Status(
|
||||
port::error::NOT_FOUND,
|
||||
port::Printf("No executors registered for ordinal %d", config.ordinal));
|
||||
}
|
||||
// The ordinal already selected the entry; a hit additionally requires the
// plugin_config and device_options to compare equal.
for (const auto& iter : entry->configurations) {
|
||||
if (iter.first.plugin_config == config.plugin_config &&
|
||||
iter.first.device_options == config.device_options) {
|
||||
VLOG(2) << "hit in cache for device ordinal " << config.ordinal;
|
||||
return iter.second.get();
|
||||
}
|
||||
}
|
||||
// Executors exist for this ordinal, but none was created with this config.
return port::Status(port::error::NOT_FOUND,
|
||||
"No executor found with a matching config.");
|
||||
}
|
||||
|
||||
void ExecutorCache::DestroyAllExecutors() { cache_.clear(); }
|
||||
// Empties the ordinal-to-entry map while holding 'mutex_'.
// NOTE(review): Get() and GetOrCreate() keep a raw Entry* after dropping
// 'mutex_', so a call here that overlaps with a lookup would presumably
// leave that pointer dangling; callers appear to be responsible for making
// sure no lookups are in flight -- confirm.
void ExecutorCache::DestroyAllExecutors() {
|
||||
mutex_lock lock{mutex_};
|
||||
cache_.clear();
|
||||
}
|
||||
|
||||
// Clears this entry's cached {config, executor} pairs while holding the
// entry's own 'configurations_mutex'.
ExecutorCache::Entry::~Entry() {
|
||||
mutex_lock lock{configurations_mutex};
|
||||
// Emptying the vector releases the executors held by its unique_ptr elements.
configurations.clear();
|
||||
}
|
||||
|
||||
} // namespace gputools
|
||||
} // namespace perftools
|
||||
|
@ -16,40 +16,62 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
|
||||
#define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
#include "tensorflow/stream_executor/lib/status.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
#include "tensorflow/stream_executor/platform/mutex.h"
|
||||
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
|
||||
|
||||
namespace perftools {
|
||||
namespace gputools {
|
||||
|
||||
// Utility class to allow Platform objects to manage cached StreamExecutors.
|
||||
// Thread-safe.
|
||||
class ExecutorCache {
|
||||
public:
|
||||
ExecutorCache() {}
|
||||
|
||||
// Inserts a new StreamExecutor with the given configuration into the cache.
|
||||
// Will not overwrite if called when a matching element is already present.
|
||||
port::Status Insert(const StreamExecutorConfig& config,
|
||||
std::unique_ptr<StreamExecutor> executor);
|
||||
// Looks up 'config' in the cache. Returns a pointer to the existing executor,
|
||||
// if already present, or creates it using 'factory', if it does not.
|
||||
// Factories may be executed concurrently for different device ordinals.
|
||||
typedef port::StatusOr<std::unique_ptr<StreamExecutor>> ExecutorFactory();
|
||||
port::StatusOr<StreamExecutor*> GetOrCreate(
|
||||
const StreamExecutorConfig& config,
|
||||
const std::function<ExecutorFactory>& factory);
|
||||
|
||||
// Returns a pointer to the described executor (if one with a matching config
|
||||
// has been created), or a NOT_FOUND status.
|
||||
port::StatusOr<StreamExecutor*> Get(const StreamExecutorConfig& config);
|
||||
|
||||
// Destroys all Executors and clears the cache.
|
||||
// Performs no synchronization - undefined behavior may occur if any executors
|
||||
// are active!
|
||||
// Performs no synchronization with the executors - undefined behavior may
|
||||
// occur if any executors are active!
|
||||
void DestroyAllExecutors();
|
||||
|
||||
private:
|
||||
typedef std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>
|
||||
Entry;
|
||||
// Each Entry contains zero or more cached executors for a device ordinal.
|
||||
struct Entry {
|
||||
~Entry();
|
||||
|
||||
// Mutex that locks the contents of each entry. The 'mutex_' of the
|
||||
// ExecutorCache class protects both the 'cache_' and the existence of each
|
||||
// Entry, but not the Entry's contents. 'configurations_mutex' protects the
|
||||
// contents of the entry after 'mutex_' has been dropped.
|
||||
mutex configurations_mutex;
|
||||
|
||||
// Vector of cached {config, executor} pairs.
|
||||
std::vector<
|
||||
std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>>
|
||||
configurations GUARDED_BY(configurations_mutex);
|
||||
};
|
||||
|
||||
// Maps ordinal number to a list of cached executors for that ordinal.
|
||||
// We key off of ordinal (instead of just looking up all fields in the
|
||||
// StreamExecutorConfig) for a slight improvement in lookup time.
|
||||
std::map<int, std::vector<Entry>> cache_;
|
||||
mutex mutex_;
|
||||
std::map<int, Entry> cache_ GUARDED_BY(mutex_);
|
||||
|
||||
SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache);
|
||||
};
|
||||
|
@ -63,23 +63,8 @@ port::StatusOr<StreamExecutor*> HostPlatform::ExecutorForDeviceWithPluginConfig(
|
||||
|
||||
port::StatusOr<StreamExecutor*> HostPlatform::GetExecutor(
|
||||
const StreamExecutorConfig& config) {
|
||||
mutex_lock lock(executors_mutex_);
|
||||
|
||||
port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
|
||||
if (status.ok()) {
|
||||
return status.ValueOrDie();
|
||||
}
|
||||
|
||||
port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
|
||||
GetUncachedExecutor(config);
|
||||
if (!executor.ok()) {
|
||||
return executor.status();
|
||||
}
|
||||
|
||||
StreamExecutor* naked_executor = executor.ValueOrDie().get();
|
||||
SE_RETURN_IF_ERROR(
|
||||
executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
|
||||
return naked_executor;
|
||||
return executor_cache_.GetOrCreate(
|
||||
config, [&]() { return GetUncachedExecutor(config); });
|
||||
}
|
||||
|
||||
port::StatusOr<std::unique_ptr<StreamExecutor>>
|
||||
|
@ -72,9 +72,6 @@ class HostPlatform : public Platform {
|
||||
// This platform's name.
|
||||
string name_;
|
||||
|
||||
// mutex that guards the ordinal-to-executor map.
|
||||
mutable mutex executors_mutex_;
|
||||
|
||||
// Cache of created StreamExecutors.
|
||||
ExecutorCache executor_cache_;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user