Change 109695551 Update FAQ. Change 109694725 Add a gradient for resize_bilinear op. Change 109694505 Don't mention variables module in docs: variables.Variable should be tf.Variable. Change 109658848 Adding an option to create a new thread-pool for each session. Change 109640570 Take the snapshot of stream-executor. + Expose an interface for scratch space allocation in the interface. Change 109638559 Let image_summary accept uint8 input. This allows users to do their own normalization / scaling if the default (very weird) behavior of image_summary is undesired. This required a slight tweak to fake_input.cc to make polymorphically typed fake inputs infer if their type attr is not set but has a default. Unfortunately, adding a second valid type to image_summary *disables* automatic implicit conversion from np.float64 to tf.float32, so this change is slightly backwards incompatible. Change 109636969 Add serialization operations for SparseTensor. Change 109636644 Update generated Op docs. Change 109634899 TensorFlow: add a markdown file for producing release notes for our releases. Seed with 0.5.0 with a boring but accurate description. Change 109634502 Let histogram_summary take any realnumbertype. It used to take only floats, now it understands ints. Change 109634434 TensorFlow: update locations where we mention python 3 support, update them to current truth. Change 109632108 Move HSV <> RGB conversions, grayscale conversions, and adjust_* ops back to tensorflow - make GPU-capable version of RGBToHSV and HSVToRGB, allows only float input/output - change docs to reflect new size constraints - change HSV format to be [0,1] for all components - add automatic dtype conversion for all adjust_* and grayscale conversion ops - fix up docs. Change 109631077 Improve optimizer exceptions. 1. grads_and_vars is now a tuple, so must be wrapped when passed to format. 2. Use '%r' instead of '%s' for dtype formatting. Base CL: 109697989
245 lines
11 KiB
C++
245 lines
11 KiB
C++
/* Copyright 2015 Google Inc. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/stream_executor/plugin_registry.h"
|
|
|
|
#include "tensorflow/stream_executor/lib/error.h"
|
|
#include "tensorflow/stream_executor/lib/stringprintf.h"
|
|
#include "tensorflow/stream_executor/multi_platform_manager.h"
|
|
|
|
namespace perftools {
|
|
namespace gputools {
|
|
|
|
// Sentinel PluginId meaning "no plugin". DefaultFactories starts every
// per-kind slot at this value, and the default-plugin lookup in GetFactory
// reports FAILED_PRECONDITION while a slot still holds it.
const PluginId kNullPlugin{nullptr};
|
|
|
|
// Returns the string representation of the specified PluginKind.
|
|
string PluginKindString(PluginKind plugin_kind) {
|
|
switch (plugin_kind) {
|
|
case PluginKind::kBlas:
|
|
return "BLAS";
|
|
case PluginKind::kDnn:
|
|
return "DNN";
|
|
case PluginKind::kFft:
|
|
return "FFT";
|
|
case PluginKind::kRng:
|
|
return "RNG";
|
|
case PluginKind::kInvalid:
|
|
default:
|
|
return "kInvalid";
|
|
}
|
|
}
|
|
|
|
PluginRegistry::DefaultFactories::DefaultFactories() :
|
|
blas(kNullPlugin), dnn(kNullPlugin), fft(kNullPlugin), rng(kNullPlugin) { }
|
|
|
|
/* static */ mutex PluginRegistry::mu_(LINKER_INITIALIZED);
|
|
/* static */ PluginRegistry* PluginRegistry::instance_ = nullptr;
|
|
|
|
PluginRegistry::PluginRegistry() {}
|
|
|
|
// Returns the process-wide registry, creating it on first call. Creation
// happens while holding mu_, so concurrent first calls agree on one instance.
// Nothing in this file deletes the instance.
/* static */ PluginRegistry* PluginRegistry::Instance() {
  mutex_lock lock{mu_};
  if (instance_ != nullptr) {
    return instance_;
  }
  instance_ = new PluginRegistry();
  return instance_;
}
|
|
|
|
// Records, or overwrites, the Platform::Id used by GetFactory(PlatformKind, ...)
// for the given platform_kind.
void PluginRegistry::MapPlatformKindToId(PlatformKind platform_kind,
                                         Platform::Id platform_id) {
  auto& mapped_id = platform_id_by_kind_[platform_kind];
  mapped_id = platform_id;
}
|
|
|
|
template <typename FACTORY_TYPE>
|
|
port::Status PluginRegistry::RegisterFactoryInternal(
|
|
PluginId plugin_id, const string& plugin_name, FACTORY_TYPE factory,
|
|
std::map<PluginId, FACTORY_TYPE>* factories) {
|
|
mutex_lock lock{mu_};
|
|
|
|
if (factories->find(plugin_id) != factories->end()) {
|
|
return port::Status{
|
|
port::error::ALREADY_EXISTS,
|
|
port::Printf("Attempting to register factory for plugin %s when "
|
|
"one has already been registered",
|
|
plugin_name.c_str())};
|
|
}
|
|
|
|
(*factories)[plugin_id] = factory;
|
|
plugin_names_[plugin_id] = plugin_name;
|
|
return port::Status::OK();
|
|
}
|
|
|
|
template <typename FACTORY_TYPE>
|
|
port::StatusOr<FACTORY_TYPE> PluginRegistry::GetFactoryInternal(
|
|
PluginId plugin_id, const std::map<PluginId, FACTORY_TYPE>& factories,
|
|
const std::map<PluginId, FACTORY_TYPE>& generic_factories) const {
|
|
auto iter = factories.find(plugin_id);
|
|
if (iter == factories.end()) {
|
|
iter = generic_factories.find(plugin_id);
|
|
if (iter == generic_factories.end()) {
|
|
return port::Status{
|
|
port::error::NOT_FOUND,
|
|
port::Printf("Plugin ID %p not registered.", plugin_id)};
|
|
}
|
|
}
|
|
|
|
return iter->second;
|
|
}
|
|
|
|
// Makes plugin_id the default plugin of kind plugin_kind for platform_id. This
// is the plugin GetFactory() selects when asked for PluginConfig::kDefault.
// Logs an error and returns false if no such factory is registered, either for
// this platform specifically or for all platforms, or if plugin_kind is not
// one of the handled kinds. Returns true on success.
bool PluginRegistry::SetDefaultFactory(Platform::Id platform_id,
                                       PluginKind plugin_kind,
                                       PluginId plugin_id) {
  if (!HasFactory(platform_id, plugin_kind, plugin_id)) {
    // Resolve a readable platform name for the diagnostic when possible.
    port::StatusOr<Platform*> platform =
        MultiPlatformManager::PlatformWithId(platform_id);
    string platform_name = "<unregistered platform>";
    if (platform.ok()) {
      platform_name = platform.ValueOrDie()->Name();
    }

    LOG(ERROR) << "A factory must be registered for a platform before being "
               << "set as default! "
               << "Platform name: " << platform_name
               << ", PluginKind: " << PluginKindString(plugin_kind)
               << ", PluginId: " << plugin_id;
    return false;
  }

  switch (plugin_kind) {
    case PluginKind::kBlas:
      default_factories_[platform_id].blas = plugin_id;
      return true;
    case PluginKind::kDnn:
      default_factories_[platform_id].dnn = plugin_id;
      return true;
    case PluginKind::kFft:
      default_factories_[platform_id].fft = plugin_id;
      return true;
    case PluginKind::kRng:
      default_factories_[platform_id].rng = plugin_id;
      return true;
    default:
      LOG(ERROR) << "Invalid plugin kind specified: "
                 << static_cast<int>(plugin_kind);
      return false;
  }
}
|
|
|
|
bool PluginRegistry::HasFactory(const PluginFactories& factories,
|
|
PluginKind plugin_kind,
|
|
PluginId plugin_id) const {
|
|
switch (plugin_kind) {
|
|
case PluginKind::kBlas:
|
|
return factories.blas.find(plugin_id) != factories.blas.end();
|
|
case PluginKind::kDnn:
|
|
return factories.dnn.find(plugin_id) != factories.dnn.end();
|
|
case PluginKind::kFft:
|
|
return factories.fft.find(plugin_id) != factories.fft.end();
|
|
case PluginKind::kRng:
|
|
return factories.rng.find(plugin_id) != factories.rng.end();
|
|
default:
|
|
LOG(ERROR) << "Invalid plugin kind specified: "
|
|
<< PluginKindString(plugin_kind);
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Returns whether a factory for (plugin_kind, plugin_id) is registered either
// for platform_id specifically or for all platforms.
bool PluginRegistry::HasFactory(Platform::Id platform_id,
                                PluginKind plugin_kind,
                                PluginId plugin_id) const {
  // Platform-specific registrations are checked first, then generic ones;
  // the generic check is skipped when the specific one already succeeds.
  auto platform_entry = factories_.find(platform_id);
  const bool has_platform_specific =
      platform_entry != factories_.end() &&
      HasFactory(platform_entry->second, plugin_kind, plugin_id);
  return has_platform_specific ||
         HasFactory(generic_factories_, plugin_kind, plugin_id);
}
|
|
|
|
// Explicit instantiations to support types exposed in user/public API.
|
|
#define EMIT_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE, FACTORY_VAR, PLUGIN_STRING) \
|
|
template port::StatusOr<PluginRegistry::FACTORY_TYPE> \
|
|
PluginRegistry::GetFactoryInternal<PluginRegistry::FACTORY_TYPE>( \
|
|
PluginId plugin_id, \
|
|
const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& factories, \
|
|
const std::map<PluginId, PluginRegistry::FACTORY_TYPE>& \
|
|
generic_factories) const; \
|
|
\
|
|
template port::Status \
|
|
PluginRegistry::RegisterFactoryInternal<PluginRegistry::FACTORY_TYPE>( \
|
|
PluginId plugin_id, const string& plugin_name, \
|
|
PluginRegistry::FACTORY_TYPE factory, \
|
|
std::map<PluginId, PluginRegistry::FACTORY_TYPE>* factories); \
|
|
\
|
|
template <> \
|
|
port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
|
|
Platform::Id platform_id, PluginId plugin_id, const string& name, \
|
|
PluginRegistry::FACTORY_TYPE factory) { \
|
|
return RegisterFactoryInternal(plugin_id, name, factory, \
|
|
&factories_[platform_id].FACTORY_VAR); \
|
|
} \
|
|
\
|
|
template <> \
|
|
port::Status PluginRegistry::RegisterFactoryForAllPlatforms< \
|
|
PluginRegistry::FACTORY_TYPE>(PluginId plugin_id, const string& name, \
|
|
PluginRegistry::FACTORY_TYPE factory) { \
|
|
return RegisterFactoryInternal(plugin_id, name, factory, \
|
|
&generic_factories_.FACTORY_VAR); \
|
|
} \
|
|
\
|
|
template <> \
|
|
port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
|
|
Platform::Id platform_id, PluginId plugin_id) { \
|
|
if (plugin_id == PluginConfig::kDefault) { \
|
|
plugin_id = default_factories_[platform_id].FACTORY_VAR; \
|
|
\
|
|
if (plugin_id == kNullPlugin) { \
|
|
return port::Status{port::error::FAILED_PRECONDITION, \
|
|
"No suitable " PLUGIN_STRING \
|
|
" plugin registered. Have you linked in a " \
|
|
PLUGIN_STRING "-providing plugin?"}; \
|
|
} else { \
|
|
VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, " \
|
|
<< plugin_names_[plugin_id]; \
|
|
} \
|
|
} \
|
|
return GetFactoryInternal(plugin_id, factories_[platform_id].FACTORY_VAR, \
|
|
generic_factories_.FACTORY_VAR); \
|
|
} \
|
|
\
|
|
/* TODO(b/22689637): Also temporary WRT MultiPlatformManager */ \
|
|
template <> \
|
|
port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory( \
|
|
PlatformKind platform_kind, PluginId plugin_id) { \
|
|
auto iter = platform_id_by_kind_.find(platform_kind); \
|
|
if (iter == platform_id_by_kind_.end()) { \
|
|
return port::Status{port::error::FAILED_PRECONDITION, \
|
|
port::Printf("Platform kind %d not registered.", \
|
|
static_cast<int>(platform_kind))}; \
|
|
} \
|
|
return GetFactory<PluginRegistry::FACTORY_TYPE>(iter->second, plugin_id); \
|
|
}
|
|
|
|
EMIT_PLUGIN_SPECIALIZATIONS(BlasFactory, blas, "BLAS");
|
|
EMIT_PLUGIN_SPECIALIZATIONS(DnnFactory, dnn, "DNN");
|
|
EMIT_PLUGIN_SPECIALIZATIONS(FftFactory, fft, "FFT");
|
|
EMIT_PLUGIN_SPECIALIZATIONS(RngFactory, rng, "RNG");
|
|
|
|
} // namespace gputools
|
|
} // namespace perftools
|