[SE] Make ExecutorCache thread-safe, change ExecutorCache::Insert to ExecutorCache::GetOrCreate. Add support for creating Executors for different device ordinals in parallel.

[XLA] Create Executors in parallel.

PiperOrigin-RevId: 163734988
This commit is contained in:
Peter Hawkins 2017-07-31 12:58:44 -07:00 committed by TensorFlower Gardener
parent 7ebed6678c
commit 122750a879
9 changed files with 120 additions and 103 deletions

View File

@ -64,23 +64,8 @@ ExecutorPlatform::ExecutorForDeviceWithPluginConfig(
port::StatusOr<StreamExecutor*> ExecutorPlatform::GetExecutor(
const StreamExecutorConfig& config) {
mutex_lock lock(executors_mutex_);
port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
if (status.ok()) {
return status.ValueOrDie();
}
port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
GetUncachedExecutor(config);
if (!executor.ok()) {
return executor.status();
}
StreamExecutor* naked_executor = executor.ValueOrDie().get();
SE_RETURN_IF_ERROR(
executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
return naked_executor;
return executor_cache_.GetOrCreate(
config, [&]() { return GetUncachedExecutor(config); });
}
port::StatusOr<std::unique_ptr<StreamExecutor>>

View File

@ -67,9 +67,6 @@ class ExecutorPlatform : public Platform {
// This platform's name.
string name_;
// mutex that guards the ordinal-to-executor map.
mutable mutex executors_mutex_;
// Cache of created StreamExecutors.
ExecutorCache executor_cache_;

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -140,21 +141,32 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
device_count = 1;
}
std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
for (int i = 0; i < device_count; ++i) {
se::StreamExecutorConfig config;
config.ordinal = i;
auto executor_status = platform->GetExecutor(config);
if (executor_status.ok()) {
se::StreamExecutor* executor = executor_status.ValueOrDie();
if (IsDeviceSupported(executor)) {
stream_executors[i] = executor;
}
} else {
LOG(WARNING) << "unable to create StreamExecutor for " << platform->Name()
<< ":" << i << ": "
<< executor_status.status().error_message();
VLOG(1) << "Initializing devices";
{
tensorflow::thread::ThreadPool thread_pool(
tensorflow::Env::Default(), "device_initialization", device_count);
for (int i = 0; i < device_count; ++i) {
thread_pool.Schedule([platform, i, &stream_executors]() {
VLOG(1) << "Started device init " << i;
se::StreamExecutorConfig config;
config.ordinal = i;
auto executor_status = platform->GetExecutor(config);
if (executor_status.ok()) {
se::StreamExecutor* executor = executor_status.ValueOrDie();
if (IsDeviceSupported(executor)) {
stream_executors[i] = executor;
}
} else {
LOG(WARNING) << "unable to create StreamExecutor for "
<< platform->Name() << ":" << i << ": "
<< executor_status.status().error_message();
}
VLOG(1) << "Finished device init " << i;
});
}
// Block here in thread_pool destructor until all devices are initialized.
}
VLOG(1) << "Device initialization complete";
if (std::all_of(stream_executors.begin(), stream_executors.end(),
[](se::StreamExecutor* s) { return s == nullptr; })) {
return InternalError("no supported devices found for platform %s",

View File

@ -127,23 +127,8 @@ port::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDeviceWithPluginConfig(
port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
const StreamExecutorConfig& config) {
mutex_lock lock(mu_);
port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
if (status.ok()) {
return status.ValueOrDie();
}
port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
GetUncachedExecutor(config);
if (!executor.ok()) {
return executor.status();
}
StreamExecutor* naked_executor = executor.ValueOrDie().get();
SE_RETURN_IF_ERROR(
executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
return naked_executor;
return executor_cache_.GetOrCreate(
config, [&]() { return GetUncachedExecutor(config); });
}
port::StatusOr<std::unique_ptr<StreamExecutor>>

View File

@ -88,9 +88,6 @@ class CudaPlatform : public Platform {
// This platform's name.
string name_;
// mutex that guards internal state.
mutable mutex mu_;
// Cache of created executors.
ExecutorCache executor_cache_;

View File

@ -20,39 +20,76 @@ limitations under the License.
namespace perftools {
namespace gputools {
port::Status ExecutorCache::Insert(const StreamExecutorConfig& config,
std::unique_ptr<StreamExecutor> entry) {
if (Get(config).ok()) {
return port::Status(port::error::ALREADY_EXISTS,
"An executor with a matching config already exists.");
port::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
const StreamExecutorConfig& config,
const std::function<ExecutorFactory>& factory) {
Entry* entry = nullptr;
{
mutex_lock lock{mutex_};
entry = &cache_[config.ordinal];
// Release the map lock; the address of 'entry' is stable because
// std::map guarantees reference stability.
}
cache_[config.ordinal].emplace_back(Entry(config, std::move(entry)));
return port::Status::OK();
}
port::StatusOr<StreamExecutor*> ExecutorCache::Get(
const StreamExecutorConfig& config) {
auto entries = cache_.find(config.ordinal);
if (entries == cache_.end()) {
return port::Status(
port::error::NOT_FOUND,
port::Printf("No executors registered for ordinal %d", config.ordinal));
}
for (const auto& iter : entries->second) {
// Acquire the per-Entry mutex without holding the map mutex. Initializing
// an Executor may be expensive, so we want to allow concurrent
// initialization of different entries.
mutex_lock lock{entry->configurations_mutex};
for (const auto& iter : entry->configurations) {
if (iter.first.plugin_config == config.plugin_config &&
iter.first.device_options == config.device_options) {
VLOG(2) << "hit in cache";
return iter.second.get();
}
}
VLOG(2) << "building executor";
port::StatusOr<std::unique_ptr<StreamExecutor>> result = factory();
if (!result.ok()) {
VLOG(2) << "failed to get build executor: " << result.status();
// If construction failed, leave the cache Entry around, but with a null
// executor.
return result.status();
}
entry->configurations.emplace_back(config, std::move(result.ValueOrDie()));
return entry->configurations.back().second.get();
}
port::StatusOr<StreamExecutor*> ExecutorCache::Get(
const StreamExecutorConfig& config) {
Entry* entry = nullptr;
{
mutex_lock lock{mutex_};
entry = &cache_[config.ordinal];
// Release the map lock; the address of 'entry' is stable because
// std::map guarantees reference stability.
}
mutex_lock lock{entry->configurations_mutex};
if (entry->configurations.empty()) {
return port::Status(
port::error::NOT_FOUND,
port::Printf("No executors registered for ordinal %d", config.ordinal));
}
for (const auto& iter : entry->configurations) {
if (iter.first.plugin_config == config.plugin_config &&
iter.first.device_options == config.device_options) {
VLOG(2) << "hit in cache for device ordinal " << config.ordinal;
return iter.second.get();
}
}
return port::Status(port::error::NOT_FOUND,
"No executor found with a matching config.");
}
void ExecutorCache::DestroyAllExecutors() { cache_.clear(); }
void ExecutorCache::DestroyAllExecutors() {
mutex_lock lock{mutex_};
cache_.clear();
}
ExecutorCache::Entry::~Entry() {
mutex_lock lock{configurations_mutex};
configurations.clear();
}
} // namespace gputools
} // namespace perftools

View File

@ -16,40 +16,62 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
#define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
#include <functional>
#include <map>
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
namespace perftools {
namespace gputools {
// Utility class to allow Platform objects to manage cached StreamExecutors.
// Thread-safe.
class ExecutorCache {
public:
ExecutorCache() {}
// Inserts a new StreamExecutor with the given configuration into the cache.
// Will not overwrite if called when a matching element is already present.
port::Status Insert(const StreamExecutorConfig& config,
std::unique_ptr<StreamExecutor> executor);
// Looks up 'config' in the cache. Returns a pointer to the existing executor,
// if already present, or creates it using 'factory', if it does not.
// Factories may be executed concurrently for different device ordinals.
typedef port::StatusOr<std::unique_ptr<StreamExecutor>> ExecutorFactory();
port::StatusOr<StreamExecutor*> GetOrCreate(
const StreamExecutorConfig& config,
const std::function<ExecutorFactory>& factory);
// Returns a pointer to the described executor (if one with a matching config
// has been created), or a NOT_FOUND status.
port::StatusOr<StreamExecutor*> Get(const StreamExecutorConfig& config);
// Destroys all Executors and clears the cache.
// Performs no synchronization - undefined behavior may occur if any executors
// are active!
// Performs no synchronization with the executors - undefined behavior may
// occur if any executors are active!
void DestroyAllExecutors();
private:
typedef std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>
Entry;
// Each Entry contains zero or more cached executors for a device ordinal.
struct Entry {
~Entry();
// Mutex that locks the contents of each entry. The 'mutex_' of the
// ExecutorCache class protects both the 'cache_' and the existence of each
// Entry, but not the Entry's contents. 'configurations_mutex' protects the
// contents of the entry after 'mutex_' has been dropped.
mutex configurations_mutex;
// Vector of cached {config, executor} pairs.
std::vector<
std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>>
configurations GUARDED_BY(configurations_mutex);
};
// Maps ordinal number to a list of cached executors for that ordinal.
// We key off of ordinal (instead of just looking up all fields in the
// StreamExecutorConfig) for a slight improvement in lookup time.
std::map<int, std::vector<Entry>> cache_;
mutex mutex_;
std::map<int, Entry> cache_ GUARDED_BY(mutex_);
SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache);
};

View File

@ -63,23 +63,8 @@ port::StatusOr<StreamExecutor*> HostPlatform::ExecutorForDeviceWithPluginConfig(
port::StatusOr<StreamExecutor*> HostPlatform::GetExecutor(
const StreamExecutorConfig& config) {
mutex_lock lock(executors_mutex_);
port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
if (status.ok()) {
return status.ValueOrDie();
}
port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
GetUncachedExecutor(config);
if (!executor.ok()) {
return executor.status();
}
StreamExecutor* naked_executor = executor.ValueOrDie().get();
SE_RETURN_IF_ERROR(
executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
return naked_executor;
return executor_cache_.GetOrCreate(
config, [&]() { return GetUncachedExecutor(config); });
}
port::StatusOr<std::unique_ptr<StreamExecutor>>

View File

@ -72,9 +72,6 @@ class HostPlatform : public Platform {
// This platform's name.
string name_;
// mutex that guards the ordinal-to-executor map.
mutable mutex executors_mutex_;
// Cache of created StreamExecutors.
ExecutorCache executor_cache_;