[XLA] Add platform name to custom-call target registry.

Remove the assumption that all entries in the registry are for the CPU
platform.

PiperOrigin-RevId: 247542177
This commit is contained in:
Justin Lebar 2019-05-09 19:55:49 -07:00 committed by TensorFlower Gardener
parent 92cec71856
commit 127471f213
5 changed files with 64 additions and 46 deletions

View File

@ -101,8 +101,8 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name,
"Argument to RegisterCpuCustomCallTargetRegistry was not a "
"xla._CPU_CUSTOM_CALL_TARGET capsule.");
}
CustomCallTargetRegistry::Global()->Register(fn_name,
static_cast<void*>(capsule));
CustomCallTargetRegistry::Global()->Register(
fn_name, static_cast<void*>(capsule), "Host");
return Status::OK();
}

View File

@ -147,9 +147,10 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
// On Mac OS X, 'name' may have a leading underscore prefix, even though the
// registered name may not.
std::string stripped_name(name.begin() + 1, name.end());
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name);
func_addr =
xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
} else {
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name);
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host");
}
if (func_addr == nullptr) {
@ -219,7 +220,7 @@ bool RegisterKnownJITSymbols() {
auto* function_address = \
reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
function_address); \
function_address, "Host"); \
CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
"__xla_cpu_runtime_" #base_name); \
} while (false)
@ -250,8 +251,10 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee),
"Host");
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee),
"Host");
#undef REGISTER_CPU_RUNTIME_SYMBOL
@ -259,11 +262,12 @@ bool RegisterKnownJITSymbols() {
// Unfortunately the double versions are overloaded on some systems, e.g.
// Mac so we need an explicit cast. This requires passing the function signature
// for that case.
#define REGISTER_LIBM_SYMBOL(name, double_sig) \
do { \
registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \
registry->Register( \
#name, reinterpret_cast<void*>(static_cast<double_sig>(name))); \
#define REGISTER_LIBM_SYMBOL(name, double_sig) \
do { \
registry->Register(#name "f", reinterpret_cast<void*>(name##f), "Host"); \
registry->Register(#name, \
reinterpret_cast<void*>(static_cast<double_sig>(name)), \
"Host"); \
} while (false)
REGISTER_LIBM_SYMBOL(acos, double (*)(double));
@ -321,8 +325,9 @@ bool RegisterKnownJITSymbols() {
#ifdef __APPLE__
REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
registry->Register("__sincosf_stret",
reinterpret_cast<void*>(__sincosf_stret));
registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret));
reinterpret_cast<void*>(__sincosf_stret), "Host");
registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret),
"Host");
#else
REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
#endif
@ -335,19 +340,19 @@ bool RegisterKnownJITSymbols() {
#undef REGISTER_LIBM_SYMBOL
registry->Register("memcpy", reinterpret_cast<void*>(memcpy));
registry->Register("memmove", reinterpret_cast<void*>(memmove));
registry->Register("memset", reinterpret_cast<void*>(memset));
registry->Register("memcpy", reinterpret_cast<void*>(memcpy), "Host");
registry->Register("memmove", reinterpret_cast<void*>(memmove), "Host");
registry->Register("memset", reinterpret_cast<void*>(memset), "Host");
#ifdef __APPLE__
registry->Register("__bzero", reinterpret_cast<void*>(bzero));
registry->Register("__bzero", reinterpret_cast<void*>(bzero), "Host");
registry->Register("memset_pattern16",
reinterpret_cast<void*>(memset_pattern16));
reinterpret_cast<void*>(memset_pattern16), "Host");
#endif
#ifdef MEMORY_SANITIZER
registry->Register("__msan_unpoison",
reinterpret_cast<void*>(__msan_unpoison));
reinterpret_cast<void*>(__msan_unpoison), "Host");
#endif
return true;

View File

@ -23,14 +23,16 @@ CustomCallTargetRegistry* CustomCallTargetRegistry::Global() {
}
void CustomCallTargetRegistry::Register(const std::string& symbol,
void* address) {
void* address,
const std::string& platform) {
std::lock_guard<std::mutex> lock(mu_);
registered_symbols_[symbol] = address;
registered_symbols_[std::make_pair(symbol, platform)] = address;
}
void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const {
void* CustomCallTargetRegistry::Lookup(const std::string& symbol,
const std::string& platform) const {
std::lock_guard<std::mutex> lock(mu_);
auto it = registered_symbols_.find(symbol);
auto it = registered_symbols_.find(std::make_pair(symbol, platform));
return it == registered_symbols_.end() ? nullptr : it->second;
}

View File

@ -19,19 +19,20 @@ limitations under the License.
// For this reason, we avoid relying on TensorFlow and instead only use the
// standard C++ library.
#include <map>
#include <mutex> // NOLINT
#include <string>
#include <unordered_map>
namespace xla {
// The CPU JIT compiler uses this registry to resolve symbolic CustomCall
// targets; so when using the CPU JIT, CustomCall targets need to be registered
// here with the symbol name used in the CustomCall.
// XLA JIT compilers use this registry to resolve symbolic CustomCall targets;
// so when using XLA as a JIT, CustomCall targets need to be registered here
// with the symbol name used in the CustomCall.
//
// The XLA AOT compiler links using a standard offline linker; so when compiling
// in AOT mode, you *also* need to make sure the name of the callee (presumably
// implemented in C++) matches up with the symbolic name used in the CustomCall.
// The XLA:CPU ahead-of-time (AOT) compiler links using a standard offline
// linker; so when compiling in CPU AOT mode, you *also* need to make sure the
// name of the callee (presumably implemented in C++) matches up with the
// symbolic name used in the CustomCall.
//
// We maintain the registry in both the JIT and the AOT cases for simplicity,
// but we only use it when running in JIT mode.
@ -39,32 +40,42 @@ class CustomCallTargetRegistry {
public:
static CustomCallTargetRegistry* Global();
void Register(const std::string& symbol, void* address);
void* Lookup(const std::string& symbol) const;
void Register(const std::string& symbol, void* address,
const std::string& platform);
void* Lookup(const std::string& symbol, const std::string& platform) const;
private:
std::unordered_map<std::string, void*> registered_symbols_;
// Maps the pair (symbol, platform) to a C function implementing a custom call
// named `symbol` for StreamExecutor platform `platform`.
//
// Different platforms have different ABIs. TODO(jlebar): Describe them!
//
// (We std::map rather than std::unordered_map because the STL doesn't provide
// a default hasher for pair<string, string>, and we want to avoid pulling in
// dependencies that might define this.)
std::map<std::pair<std::string, std::string>, void*> registered_symbols_;
mutable std::mutex mu_;
};
class RegisterCustomCallTarget {
public:
explicit RegisterCustomCallTarget(const std::string& name, void* address) {
CustomCallTargetRegistry::Global()->Register(name, address);
explicit RegisterCustomCallTarget(const std::string& name, void* address,
const std::string& platform) {
CustomCallTargetRegistry::Global()->Register(name, address, platform);
}
};
#define XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
#define XLA_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
counter) \
static ::xla::RegisterCustomCallTarget XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT( \
custom_call_target_register, counter)(symbol, \
reinterpret_cast<void*>(address))
#define XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
platform, counter) \
static ::xla::RegisterCustomCallTarget XLA_REGISTER_CUSTOM_CALL_CONCAT( \
custom_call_target_register, counter)( \
symbol, reinterpret_cast<void*>(address), platform)
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
__COUNTER__)
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, "Host", \
__COUNTER__)
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(function) \
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)

View File

@ -53,7 +53,7 @@ StatusOr<Literal> HandleEvaluatorCustomCall(
HloInstruction* custom_call, absl::Span<const Literal*> operands) {
// Find the target C function in the global registry.
auto* registry = CustomCallTargetRegistry::Global();
void* target_fn = registry->Lookup(custom_call->custom_call_target());
void* target_fn = registry->Lookup(custom_call->custom_call_target(), "Host");
if (!target_fn) {
return NotFound("Custom call target '%s' was not registered",
custom_call->custom_call_target());