[XLA] Add platform name to custom-call target registry.
Remove the assumption that all entries in the registry are for the CPU platform. PiperOrigin-RevId: 247542177
This commit is contained in:
parent
92cec71856
commit
127471f213
@ -101,8 +101,8 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name,
|
||||
"Argument to RegisterCpuCustomCallTargetRegistry was not a "
|
||||
"xla._CPU_CUSTOM_CALL_TARGET capsule.");
|
||||
}
|
||||
CustomCallTargetRegistry::Global()->Register(fn_name,
|
||||
static_cast<void*>(capsule));
|
||||
CustomCallTargetRegistry::Global()->Register(
|
||||
fn_name, static_cast<void*>(capsule), "Host");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -147,9 +147,10 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
|
||||
// On Mac OS X, 'name' may have a leading underscore prefix, even though the
|
||||
// registered name may not.
|
||||
std::string stripped_name(name.begin() + 1, name.end());
|
||||
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name);
|
||||
func_addr =
|
||||
xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
|
||||
} else {
|
||||
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name);
|
||||
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host");
|
||||
}
|
||||
|
||||
if (func_addr == nullptr) {
|
||||
@ -219,7 +220,7 @@ bool RegisterKnownJITSymbols() {
|
||||
auto* function_address = \
|
||||
reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
|
||||
registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
|
||||
function_address); \
|
||||
function_address, "Host"); \
|
||||
CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
|
||||
"__xla_cpu_runtime_" #base_name); \
|
||||
} while (false)
|
||||
@ -250,8 +251,10 @@ bool RegisterKnownJITSymbols() {
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);
|
||||
|
||||
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
|
||||
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
|
||||
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee),
|
||||
"Host");
|
||||
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee),
|
||||
"Host");
|
||||
|
||||
#undef REGISTER_CPU_RUNTIME_SYMBOL
|
||||
|
||||
@ -259,11 +262,12 @@ bool RegisterKnownJITSymbols() {
|
||||
// Unfortunately the double versions are overloaded on some systems, e.g.
|
||||
// Mac, so we need an explicit cast. This requires passing the function signature
|
||||
// for that case.
|
||||
#define REGISTER_LIBM_SYMBOL(name, double_sig) \
|
||||
do { \
|
||||
registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \
|
||||
registry->Register( \
|
||||
#name, reinterpret_cast<void*>(static_cast<double_sig>(name))); \
|
||||
#define REGISTER_LIBM_SYMBOL(name, double_sig) \
|
||||
do { \
|
||||
registry->Register(#name "f", reinterpret_cast<void*>(name##f), "Host"); \
|
||||
registry->Register(#name, \
|
||||
reinterpret_cast<void*>(static_cast<double_sig>(name)), \
|
||||
"Host"); \
|
||||
} while (false)
|
||||
|
||||
REGISTER_LIBM_SYMBOL(acos, double (*)(double));
|
||||
@ -321,8 +325,9 @@ bool RegisterKnownJITSymbols() {
|
||||
#ifdef __APPLE__
|
||||
REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
|
||||
registry->Register("__sincosf_stret",
|
||||
reinterpret_cast<void*>(__sincosf_stret));
|
||||
registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret));
|
||||
reinterpret_cast<void*>(__sincosf_stret), "Host");
|
||||
registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret),
|
||||
"Host");
|
||||
#else
|
||||
REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
|
||||
#endif
|
||||
@ -335,19 +340,19 @@ bool RegisterKnownJITSymbols() {
|
||||
|
||||
#undef REGISTER_LIBM_SYMBOL
|
||||
|
||||
registry->Register("memcpy", reinterpret_cast<void*>(memcpy));
|
||||
registry->Register("memmove", reinterpret_cast<void*>(memmove));
|
||||
registry->Register("memset", reinterpret_cast<void*>(memset));
|
||||
registry->Register("memcpy", reinterpret_cast<void*>(memcpy), "Host");
|
||||
registry->Register("memmove", reinterpret_cast<void*>(memmove), "Host");
|
||||
registry->Register("memset", reinterpret_cast<void*>(memset), "Host");
|
||||
|
||||
#ifdef __APPLE__
|
||||
registry->Register("__bzero", reinterpret_cast<void*>(bzero));
|
||||
registry->Register("__bzero", reinterpret_cast<void*>(bzero), "Host");
|
||||
registry->Register("memset_pattern16",
|
||||
reinterpret_cast<void*>(memset_pattern16));
|
||||
reinterpret_cast<void*>(memset_pattern16), "Host");
|
||||
#endif
|
||||
|
||||
#ifdef MEMORY_SANITIZER
|
||||
registry->Register("__msan_unpoison",
|
||||
reinterpret_cast<void*>(__msan_unpoison));
|
||||
reinterpret_cast<void*>(__msan_unpoison), "Host");
|
||||
#endif
|
||||
|
||||
return true;
|
||||
|
@ -23,14 +23,16 @@ CustomCallTargetRegistry* CustomCallTargetRegistry::Global() {
|
||||
}
|
||||
|
||||
void CustomCallTargetRegistry::Register(const std::string& symbol,
|
||||
void* address) {
|
||||
void* address,
|
||||
const std::string& platform) {
|
||||
std::lock_guard<std::mutex> lock(mu_);
|
||||
registered_symbols_[symbol] = address;
|
||||
registered_symbols_[std::make_pair(symbol, platform)] = address;
|
||||
}
|
||||
|
||||
void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const {
|
||||
void* CustomCallTargetRegistry::Lookup(const std::string& symbol,
|
||||
const std::string& platform) const {
|
||||
std::lock_guard<std::mutex> lock(mu_);
|
||||
auto it = registered_symbols_.find(symbol);
|
||||
auto it = registered_symbols_.find(std::make_pair(symbol, platform));
|
||||
return it == registered_symbols_.end() ? nullptr : it->second;
|
||||
}
|
||||
|
||||
|
@ -19,19 +19,20 @@ limitations under the License.
|
||||
// For this reason, we avoid relying on TensorFlow and instead only use the
|
||||
// standard C++ library.
|
||||
|
||||
#include <map>
|
||||
#include <mutex> // NOLINT
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace xla {
|
||||
|
||||
// The CPU JIT compiler uses this registry to resolve symbolic CustomCall
|
||||
// targets; so when using the CPU JIT, CustomCall targets need to be registered
|
||||
// here with the symbol name used in the CustomCall.
|
||||
// XLA JIT compilers use this registry to resolve symbolic CustomCall targets;
|
||||
// so when using XLA as a JIT, CustomCall targets need to be registered here
|
||||
// with the symbol name used in the CustomCall.
|
||||
//
|
||||
// The XLA AOT compiler links using a standard offline linker; so when compiling
|
||||
// in AOT mode, you *also* need to make sure the name of the callee (presumably
|
||||
// implemented in C++) matches up with the symbolic name used in the CustomCall.
|
||||
// The XLA:CPU ahead-of-time (AOT) compiler links using a standard offline
|
||||
// linker; so when compiling in CPU AOT mode, you *also* need to make sure the
|
||||
// name of the callee (presumably implemented in C++) matches up with the
|
||||
// symbolic name used in the CustomCall.
|
||||
//
|
||||
// We maintain the registry in both the JIT and the AOT cases for simplicity,
|
||||
// but we only use it when running in JIT mode.
|
||||
@ -39,32 +40,42 @@ class CustomCallTargetRegistry {
|
||||
public:
|
||||
static CustomCallTargetRegistry* Global();
|
||||
|
||||
void Register(const std::string& symbol, void* address);
|
||||
void* Lookup(const std::string& symbol) const;
|
||||
void Register(const std::string& symbol, void* address,
|
||||
const std::string& platform);
|
||||
void* Lookup(const std::string& symbol, const std::string& platform) const;
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, void*> registered_symbols_;
|
||||
// Maps the pair (symbol, platform) to a C function implementing a custom call
|
||||
// named `symbol` for StreamExecutor platform `platform`.
|
||||
//
|
||||
// Different platforms have different ABIs. TODO(jlebar): Describe them!
|
||||
//
|
||||
// (We use std::map rather than std::unordered_map because the STL doesn't provide
|
||||
// a default hasher for pair<string, string>, and we want to avoid pulling in
|
||||
// dependencies that might define this.)
|
||||
std::map<std::pair<std::string, std::string>, void*> registered_symbols_;
|
||||
mutable std::mutex mu_;
|
||||
};
|
||||
|
||||
class RegisterCustomCallTarget {
|
||||
public:
|
||||
explicit RegisterCustomCallTarget(const std::string& name, void* address) {
|
||||
CustomCallTargetRegistry::Global()->Register(name, address);
|
||||
explicit RegisterCustomCallTarget(const std::string& name, void* address,
|
||||
const std::string& platform) {
|
||||
CustomCallTargetRegistry::Global()->Register(name, address, platform);
|
||||
}
|
||||
};
|
||||
|
||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
|
||||
#define XLA_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
|
||||
|
||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
|
||||
counter) \
|
||||
static ::xla::RegisterCustomCallTarget XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT( \
|
||||
custom_call_target_register, counter)(symbol, \
|
||||
reinterpret_cast<void*>(address))
|
||||
#define XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
|
||||
platform, counter) \
|
||||
static ::xla::RegisterCustomCallTarget XLA_REGISTER_CUSTOM_CALL_CONCAT( \
|
||||
custom_call_target_register, counter)( \
|
||||
symbol, reinterpret_cast<void*>(address), platform)
|
||||
|
||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
|
||||
__COUNTER__)
|
||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
|
||||
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, "Host", \
|
||||
__COUNTER__)
|
||||
|
||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(function) \
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)
|
||||
|
@ -53,7 +53,7 @@ StatusOr<Literal> HandleEvaluatorCustomCall(
|
||||
HloInstruction* custom_call, absl::Span<const Literal*> operands) {
|
||||
// Find the target C function in the global registry.
|
||||
auto* registry = CustomCallTargetRegistry::Global();
|
||||
void* target_fn = registry->Lookup(custom_call->custom_call_target());
|
||||
void* target_fn = registry->Lookup(custom_call->custom_call_target(), "Host");
|
||||
if (!target_fn) {
|
||||
return NotFound("Custom call target '%s' was not registered",
|
||||
custom_call->custom_call_target());
|
||||
|
Loading…
x
Reference in New Issue
Block a user