81 lines
3.1 KiB
C++
81 lines
3.1 KiB
C++
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#ifndef TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
|
|
#define TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
|
|
|
|
#include <functional>
|
|
#include <map>
|
|
|
|
#include "absl/synchronization/mutex.h"
|
|
#include "tensorflow/stream_executor/lib/status.h"
|
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
|
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
|
|
|
|
namespace stream_executor {
|
|
|
|
// Utility class to allow Platform objects to manage cached StreamExecutors.
|
|
// Thread-safe.
|
|
class ExecutorCache {
|
|
public:
|
|
ExecutorCache() {}
|
|
|
|
// Looks up 'config' in the cache. Returns a pointer to the existing executor,
|
|
// if already present, or creates it using 'factory', if it does not.
|
|
// Factories may be executed concurrently for different device ordinals.
|
|
typedef port::StatusOr<std::unique_ptr<StreamExecutor>> ExecutorFactory();
|
|
port::StatusOr<StreamExecutor*> GetOrCreate(
|
|
const StreamExecutorConfig& config,
|
|
const std::function<ExecutorFactory>& factory);
|
|
|
|
// Returns a pointer to the described executor (if one with a matching config
|
|
// has been created), or a NOT_FOUND status.
|
|
port::StatusOr<StreamExecutor*> Get(const StreamExecutorConfig& config);
|
|
|
|
// Destroys all Executors and clears the cache.
|
|
// Performs no synchronization with the executors - undefined behavior may
|
|
// occur if any executors are active!
|
|
void DestroyAllExecutors();
|
|
|
|
private:
|
|
// Each Entry contains zero or more cached executors for a device ordinal.
|
|
struct Entry {
|
|
~Entry();
|
|
|
|
// Mutex that guards the contents of each entry. The 'mutex_' of the
|
|
// ExecutorCache class protects both the 'cache_' and the existence of each
|
|
// Entry, but not the Entry's contents. 'configurations_mutex' protects the
|
|
// contents of the entry after 'mutex_' has been dropped.
|
|
absl::Mutex configurations_mutex;
|
|
|
|
// Vector of cached {config, executor} pairs.
|
|
std::vector<
|
|
std::pair<StreamExecutorConfig, std::unique_ptr<StreamExecutor>>>
|
|
configurations TF_GUARDED_BY(configurations_mutex);
|
|
};
|
|
|
|
// Maps ordinal number to a list of cached executors for that ordinal.
|
|
// We key off of ordinal (instead of just looking up all fields in the
|
|
// StreamExecutorConfig) for a slight improvement in lookup time.
|
|
absl::Mutex mutex_;
|
|
std::map<int, Entry> cache_ TF_GUARDED_BY(mutex_);
|
|
|
|
SE_DISALLOW_COPY_AND_ASSIGN(ExecutorCache);
|
|
};
|
|
|
|
} // namespace stream_executor
|
|
|
|
#endif // TENSORFLOW_STREAM_EXECUTOR_EXECUTOR_CACHE_H_
|