[SE] Make ExecutorCache thread-safe; replace ExecutorCache::Insert with ExecutorCache::GetOrCreate. Add support for creating Executors for different device ordinals in parallel.

[XLA] Create Executors in parallel.

PiperOrigin-RevId: 163734988
Author: Peter Hawkins, 2017-07-31 12:58:44 -07:00 (committed by TensorFlower Gardener)
Commit: 122750a879
Parent: 7ebed6678c
9 changed files with 120 additions and 103 deletions
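For readers skimming the diffs below, the heart of the change is the new ExecutorCache::GetOrCreate: a map-level mutex guards the ordinal-to-Entry map, while a per-Entry mutex guards that entry's cached executors, so expensive executor construction for different device ordinals can run concurrently. What follows is a minimal standalone sketch of that locking scheme, not the TensorFlow code itself; SimpleExecutorCache and FakeExecutor are toy names, and the real cache also matches on plugin_config and device_options rather than caching a single executor per ordinal.

#include <functional>
#include <map>
#include <memory>
#include <mutex>

struct FakeExecutor {  // stand-in for StreamExecutor
  int ordinal;
};

class SimpleExecutorCache {
 public:
  using Factory = std::function<std::unique_ptr<FakeExecutor>()>;

  FakeExecutor* GetOrCreate(int ordinal, const Factory& factory) {
    Entry* entry = nullptr;
    {
      std::lock_guard<std::mutex> map_lock(map_mutex_);
      // std::map guarantees reference stability, so &cache_[ordinal]
      // remains valid after the map lock is released.
      entry = &cache_[ordinal];
    }
    // Hold only the per-entry mutex while (possibly) building the executor,
    // so executors for other ordinals can be created in parallel.
    std::lock_guard<std::mutex> entry_lock(entry->mutex);
    if (!entry->executor) {
      entry->executor = factory();
    }
    return entry->executor.get();
  }

 private:
  struct Entry {
    std::mutex mutex;
    std::unique_ptr<FakeExecutor> executor;  // one entry per ordinal, for brevity
  };
  std::mutex map_mutex_;
  std::map<int, Entry> cache_;
};

int main() {
  SimpleExecutorCache cache;
  FakeExecutor* first = cache.GetOrCreate(
      0, [] { return std::make_unique<FakeExecutor>(FakeExecutor{0}); });
  FakeExecutor* again = cache.GetOrCreate(
      0, [] { return std::make_unique<FakeExecutor>(FakeExecutor{0}); });
  return (first == again) ? 0 : 1;  // second call is a cache hit
}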


@@ -64,23 +64,8 @@ ExecutorPlatform::ExecutorForDeviceWithPluginConfig(
 port::StatusOr<StreamExecutor*> ExecutorPlatform::GetExecutor(
     const StreamExecutorConfig& config) {
-  mutex_lock lock(executors_mutex_);
-
-  port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
-  if (status.ok()) {
-    return status.ValueOrDie();
-  }
-
-  port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
-      GetUncachedExecutor(config);
-  if (!executor.ok()) {
-    return executor.status();
-  }
-
-  StreamExecutor* naked_executor = executor.ValueOrDie().get();
-  SE_RETURN_IF_ERROR(
-      executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
-  return naked_executor;
+  return executor_cache_.GetOrCreate(
+      config, [&]() { return GetUncachedExecutor(config); });
 }
 
 port::StatusOr<std::unique_ptr<StreamExecutor>>


@@ -67,9 +67,6 @@ class ExecutorPlatform : public Platform {
   // This platform's name.
   string name_;
 
-  // mutex that guards the ordinal-to-executor map.
-  mutable mutex executors_mutex_;
-
   // Cache of created StreamExecutors.
   ExecutorCache executor_cache_;


@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -140,7 +141,13 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
     device_count = 1;
   }
   std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
+  VLOG(1) << "Initializing devices";
+  {
+    tensorflow::thread::ThreadPool thread_pool(
+        tensorflow::Env::Default(), "device_initialization", device_count);
   for (int i = 0; i < device_count; ++i) {
+    thread_pool.Schedule([platform, i, &stream_executors]() {
+      VLOG(1) << "Started device init " << i;
     se::StreamExecutorConfig config;
     config.ordinal = i;
     auto executor_status = platform->GetExecutor(config);
@@ -150,11 +157,16 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
         stream_executors[i] = executor;
       }
     } else {
-      LOG(WARNING) << "unable to create StreamExecutor for " << platform->Name()
-                   << ":" << i << ": "
+      LOG(WARNING) << "unable to create StreamExecutor for "
+                   << platform->Name() << ":" << i << ": "
                    << executor_status.status().error_message();
     }
+      VLOG(1) << "Finished device init " << i;
+    });
   }
+  // Block here in thread_pool destructor until all devices are initialized.
+  }
+  VLOG(1) << "Device initialization complete";
   if (std::all_of(stream_executors.begin(), stream_executors.end(),
                   [](se::StreamExecutor* s) { return s == nullptr; })) {
     return InternalError("no supported devices found for platform %s",
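The hunk above blocks at the closing brace of the new scope: tensorflow::thread::ThreadPool joins its workers in its destructor, so every scheduled device initializer has finished before "Device initialization complete" is logged. Below is a standalone sketch of the same fan-out-then-join pattern using plain std::thread instead of the TensorFlow thread pool; the names and the results vector are toy stand-ins, not the code from this change.

#include <thread>
#include <vector>

int main() {
  const int device_count = 4;
  std::vector<int> results(device_count, -1);
  {
    std::vector<std::thread> workers;
    workers.reserve(device_count);
    for (int i = 0; i < device_count; ++i) {
      workers.emplace_back([i, &results]() {
        // Stand-in for platform->GetExecutor(config); may be slow.
        results[i] = i;
      });
    }
    // Joining here plays the role of the ThreadPool destructor above:
    // nothing after this scope runs until every initializer has finished.
    for (std::thread& t : workers) t.join();
  }
  return 0;
}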


@@ -127,23 +127,8 @@ port::StatusOr<StreamExecutor*> CudaPlatform::ExecutorForDeviceWithPluginConfig(
 port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
     const StreamExecutorConfig& config) {
-  mutex_lock lock(mu_);
-
-  port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
-  if (status.ok()) {
-    return status.ValueOrDie();
-  }
-
-  port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
-      GetUncachedExecutor(config);
-  if (!executor.ok()) {
-    return executor.status();
-  }
-
-  StreamExecutor* naked_executor = executor.ValueOrDie().get();
-  SE_RETURN_IF_ERROR(
-      executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
-  return naked_executor;
+  return executor_cache_.GetOrCreate(
+      config, [&]() { return GetUncachedExecutor(config); });
 }
 
 port::StatusOr<std::unique_ptr<StreamExecutor>>


@@ -88,9 +88,6 @@ class CudaPlatform : public Platform {
   // This platform's name.
   string name_;
 
-  // mutex that guards internal state.
-  mutable mutex mu_;
-
   // Cache of created executors.
   ExecutorCache executor_cache_;


@@ -20,39 +20,76 @@ limitations under the License.
 namespace perftools {
 namespace gputools {
 
-port::Status ExecutorCache::Insert(const StreamExecutorConfig& config,
-                                   std::unique_ptr<StreamExecutor> entry) {
-  if (Get(config).ok()) {
-    return port::Status(port::error::ALREADY_EXISTS,
-                        "An executor with a matching config already exists.");
+port::StatusOr<StreamExecutor*> ExecutorCache::GetOrCreate(
+    const StreamExecutorConfig& config,
+    const std::function<ExecutorFactory>& factory) {
+  Entry* entry = nullptr;
+  {
+    mutex_lock lock{mutex_};
+    entry = &cache_[config.ordinal];
+    // Release the map lock; the address of 'entry' is stable because
+    // std::map guarantees reference stability.
   }
 
-  cache_[config.ordinal].emplace_back(Entry(config, std::move(entry)));
-
-  return port::Status::OK();
-}
-
-port::StatusOr<StreamExecutor*> ExecutorCache::Get(
-    const StreamExecutorConfig& config) {
-  auto entries = cache_.find(config.ordinal);
-  if (entries == cache_.end()) {
-    return port::Status(
-        port::error::NOT_FOUND,
-        port::Printf("No executors registered for ordinal %d", config.ordinal));
-  }
-  for (const auto& iter : entries->second) {
+  // Acquire the per-Entry mutex without holding the map mutex. Initializing
+  // an Executor may be expensive, so we want to allow concurrent
+  // initialization of different entries.
+  mutex_lock lock{entry->configurations_mutex};
+  for (const auto& iter : entry->configurations) {
     if (iter.first.plugin_config == config.plugin_config &&
         iter.first.device_options == config.device_options) {
+      VLOG(2) << "hit in cache";
       return iter.second.get();
     }
   }
+
+  VLOG(2) << "building executor";
+  port::StatusOr<std::unique_ptr<StreamExecutor>> result = factory();
+  if (!result.ok()) {
+    VLOG(2) << "failed to get build executor: " << result.status();
+    // If construction failed, leave the cache Entry around, but with a null
+    // executor.
+    return result.status();
+  }
+  entry->configurations.emplace_back(config, std::move(result.ValueOrDie()));
+  return entry->configurations.back().second.get();
+}
+
+port::StatusOr<StreamExecutor*> ExecutorCache::Get(
+    const StreamExecutorConfig& config) {
+  Entry* entry = nullptr;
+  {
+    mutex_lock lock{mutex_};
+    entry = &cache_[config.ordinal];
+    // Release the map lock; the address of 'entry' is stable because
+    // std::map guarantees reference stability.
+  }
+  mutex_lock lock{entry->configurations_mutex};
+  if (entry->configurations.empty()) {
+    return port::Status(
+        port::error::NOT_FOUND,
+        port::Printf("No executors registered for ordinal %d", config.ordinal));
+  }
+  for (const auto& iter : entry->configurations) {
+    if (iter.first.plugin_config == config.plugin_config &&
+        iter.first.device_options == config.device_options) {
+      VLOG(2) << "hit in cache for device ordinal " << config.ordinal;
+      return iter.second.get();
+    }
+  }
   return port::Status(port::error::NOT_FOUND,
                       "No executor found with a matching config.");
 }
 
-void ExecutorCache::DestroyAllExecutors() { cache_.clear(); }
+void ExecutorCache::DestroyAllExecutors() {
+  mutex_lock lock{mutex_};
+  cache_.clear();
+}
+
+ExecutorCache::Entry::~Entry() {
+  mutex_lock lock{configurations_mutex};
+  configurations.clear();
+}
 
 }  // namespace gputools
 }  // namespace perftools


@@ -16,40 +16,62 @@ limitations under the License.
 #ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
 #define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
 
+#include <functional>
+#include <map>
+
 #include "tensorflow/stream_executor/lib/status.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform/mutex.h"
 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
 
 namespace perftools {
 namespace gputools {
 
 // Utility class to allow Platform objects to manage cached StreamExecutors.
+// Thread-safe.
 class ExecutorCache {
  public:
   ExecutorCache() {}
 
-  // Inserts a new StreamExecutor with the given configuration into the cache.
-  // Will not overwrite if called when a matching element is already present.
-  port::Status Insert(const StreamExecutorConfig& config,
-                      std::unique_ptr<StreamExecutor> executor);
+  // Looks up 'config' in the cache. Returns a pointer to the existing executor,
+  // if already present, or creates it using 'factory', if it does not.
+  // Factories may be executed concurrently for different device ordinals.
+  typedef port::StatusOr<std::unique_ptr<StreamExecutor>> ExecutorFactory();
+  port::StatusOr<StreamExecutor*> GetOrCreate(
+      const StreamExecutorConfig& config,
+      const std::function<ExecutorFactory>& factory);
 
   // Returns a pointer to the described executor (if one with a matching config
   // has been created), or a NOT_FOUND status.
   port::StatusOr<StreamExecutor*> Get(const StreamExecutorConfig& config);
 
   // Destroys all Executors and clears the cache.
-  // Performs no synchronization - undefined behavior may occur if any executors
-  // are active!
+  // Performs no synchronization with the executors - undefined behavior may
+  // occur if any executors are active!
   void DestroyAllExecutors();
 
  private:
-  typedef std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>
-      Entry;
+  // Each Entry contains zero or more cached executors for a device ordinal.
+  struct Entry {
+    ~Entry();
+
+    // Mutex that locks the contents of each entry. The 'mutex_' of the
+    // ExecutorCache class protects both the 'cache_' and the existence of each
+    // Entry, but not the Entry's contents. 'configurations_mutex' protects the
+    // contents of the entry after 'mutex_' has been dropped.
+    mutex configurations_mutex;
+
+    // Vector of cached {config, executor} pairs.
+    std::vector<
+        std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>>
+        configurations GUARDED_BY(configurations_mutex);
+  };
 
   // Maps ordinal number to a list of cached executors for that ordinal.
   // We key off of ordinal (instead of just looking up all fields in the
   // StreamExecutorConfig) for a slight improvement in lookup time.
-  std::map<int, std::vector<Entry>> cache_;
+  mutex mutex_;
+  std::map<int, Entry> cache_ GUARDED_BY(mutex_);
 
   SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache);
 };
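One subtlety in the header above: ExecutorFactory is a typedef of a function type (no arguments, returning port::StatusOr<std::unique_ptr<StreamExecutor>>), so std::function<ExecutorFactory> is simply std::function<port::StatusOr<std::unique_ptr<StreamExecutor>>()>. A tiny illustration of the same idiom with toy types (WidgetFactory and MakeWidget are hypothetical names, not part of the commit):

#include <functional>
#include <iostream>

typedef int WidgetFactory();  // function type: takes nothing, returns int

int MakeWidget() { return 42; }

int main() {
  // Equivalent to std::function<int()>.
  std::function<WidgetFactory> factory = MakeWidget;
  std::function<WidgetFactory> lambda = [] { return 7; };
  std::cout << factory() << " " << lambda() << "\n";
  return 0;
}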


@@ -63,23 +63,8 @@ port::StatusOr<StreamExecutor*> HostPlatform::ExecutorForDeviceWithPluginConfig(
 port::StatusOr<StreamExecutor*> HostPlatform::GetExecutor(
     const StreamExecutorConfig& config) {
-  mutex_lock lock(executors_mutex_);
-
-  port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
-  if (status.ok()) {
-    return status.ValueOrDie();
-  }
-
-  port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
-      GetUncachedExecutor(config);
-  if (!executor.ok()) {
-    return executor.status();
-  }
-
-  StreamExecutor* naked_executor = executor.ValueOrDie().get();
-  SE_RETURN_IF_ERROR(
-      executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
-  return naked_executor;
+  return executor_cache_.GetOrCreate(
+      config, [&]() { return GetUncachedExecutor(config); });
 }
 
 port::StatusOr<std::unique_ptr<StreamExecutor>>


@@ -72,9 +72,6 @@ class HostPlatform : public Platform {
   // This platform's name.
   string name_;
 
-  // mutex that guards the ordinal-to-executor map.
-  mutable mutex executors_mutex_;
-
   // Cache of created StreamExecutors.
   ExecutorCache executor_cache_;