[XLA] Add platform name to custom-call target registry.
Remove the assumption that all entries in the registry are for the CPU platform. PiperOrigin-RevId: 247542177
This commit is contained in:
parent
92cec71856
commit
127471f213
@ -101,8 +101,8 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name,
|
|||||||
"Argument to RegisterCpuCustomCallTargetRegistry was not a "
|
"Argument to RegisterCpuCustomCallTargetRegistry was not a "
|
||||||
"xla._CPU_CUSTOM_CALL_TARGET capsule.");
|
"xla._CPU_CUSTOM_CALL_TARGET capsule.");
|
||||||
}
|
}
|
||||||
CustomCallTargetRegistry::Global()->Register(fn_name,
|
CustomCallTargetRegistry::Global()->Register(
|
||||||
static_cast<void*>(capsule));
|
fn_name, static_cast<void*>(capsule), "Host");
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -147,9 +147,10 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
|
|||||||
// On Mac OS X, 'name' may have a leading underscore prefix, even though the
|
// On Mac OS X, 'name' may have a leading underscore prefix, even though the
|
||||||
// registered name may not.
|
// registered name may not.
|
||||||
std::string stripped_name(name.begin() + 1, name.end());
|
std::string stripped_name(name.begin() + 1, name.end());
|
||||||
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name);
|
func_addr =
|
||||||
|
xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
|
||||||
} else {
|
} else {
|
||||||
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name);
|
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (func_addr == nullptr) {
|
if (func_addr == nullptr) {
|
||||||
@ -219,7 +220,7 @@ bool RegisterKnownJITSymbols() {
|
|||||||
auto* function_address = \
|
auto* function_address = \
|
||||||
reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
|
reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
|
||||||
registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
|
registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
|
||||||
function_address); \
|
function_address, "Host"); \
|
||||||
CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
|
CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
|
||||||
"__xla_cpu_runtime_" #base_name); \
|
"__xla_cpu_runtime_" #base_name); \
|
||||||
} while (false)
|
} while (false)
|
||||||
@ -250,8 +251,10 @@ bool RegisterKnownJITSymbols() {
|
|||||||
REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
|
REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
|
||||||
REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);
|
REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);
|
||||||
|
|
||||||
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee));
|
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee),
|
||||||
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee));
|
"Host");
|
||||||
|
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee),
|
||||||
|
"Host");
|
||||||
|
|
||||||
#undef REGISTER_CPU_RUNTIME_SYMBOL
|
#undef REGISTER_CPU_RUNTIME_SYMBOL
|
||||||
|
|
||||||
@ -259,11 +262,12 @@ bool RegisterKnownJITSymbols() {
|
|||||||
// Unfortunately the double versions are overloaded on some systems, e.g.
|
// Unfortunately the double versions are overloaded on some systems, e.g.
|
||||||
// Mac so we need an explicit cast. This requires passing the function signature
|
// Mac so we need an explicit cast. This requires passing the function signature
|
||||||
// for that case.
|
// for that case.
|
||||||
#define REGISTER_LIBM_SYMBOL(name, double_sig) \
|
#define REGISTER_LIBM_SYMBOL(name, double_sig) \
|
||||||
do { \
|
do { \
|
||||||
registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \
|
registry->Register(#name "f", reinterpret_cast<void*>(name##f), "Host"); \
|
||||||
registry->Register( \
|
registry->Register(#name, \
|
||||||
#name, reinterpret_cast<void*>(static_cast<double_sig>(name))); \
|
reinterpret_cast<void*>(static_cast<double_sig>(name)), \
|
||||||
|
"Host"); \
|
||||||
} while (false)
|
} while (false)
|
||||||
|
|
||||||
REGISTER_LIBM_SYMBOL(acos, double (*)(double));
|
REGISTER_LIBM_SYMBOL(acos, double (*)(double));
|
||||||
@ -321,8 +325,9 @@ bool RegisterKnownJITSymbols() {
|
|||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
|
REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
|
||||||
registry->Register("__sincosf_stret",
|
registry->Register("__sincosf_stret",
|
||||||
reinterpret_cast<void*>(__sincosf_stret));
|
reinterpret_cast<void*>(__sincosf_stret), "Host");
|
||||||
registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret));
|
registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret),
|
||||||
|
"Host");
|
||||||
#else
|
#else
|
||||||
REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
|
REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
|
||||||
#endif
|
#endif
|
||||||
@ -335,19 +340,19 @@ bool RegisterKnownJITSymbols() {
|
|||||||
|
|
||||||
#undef REGISTER_LIBM_SYMBOL
|
#undef REGISTER_LIBM_SYMBOL
|
||||||
|
|
||||||
registry->Register("memcpy", reinterpret_cast<void*>(memcpy));
|
registry->Register("memcpy", reinterpret_cast<void*>(memcpy), "Host");
|
||||||
registry->Register("memmove", reinterpret_cast<void*>(memmove));
|
registry->Register("memmove", reinterpret_cast<void*>(memmove), "Host");
|
||||||
registry->Register("memset", reinterpret_cast<void*>(memset));
|
registry->Register("memset", reinterpret_cast<void*>(memset), "Host");
|
||||||
|
|
||||||
#ifdef __APPLE__
|
#ifdef __APPLE__
|
||||||
registry->Register("__bzero", reinterpret_cast<void*>(bzero));
|
registry->Register("__bzero", reinterpret_cast<void*>(bzero), "Host");
|
||||||
registry->Register("memset_pattern16",
|
registry->Register("memset_pattern16",
|
||||||
reinterpret_cast<void*>(memset_pattern16));
|
reinterpret_cast<void*>(memset_pattern16), "Host");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef MEMORY_SANITIZER
|
#ifdef MEMORY_SANITIZER
|
||||||
registry->Register("__msan_unpoison",
|
registry->Register("__msan_unpoison",
|
||||||
reinterpret_cast<void*>(__msan_unpoison));
|
reinterpret_cast<void*>(__msan_unpoison), "Host");
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
|
@ -23,14 +23,16 @@ CustomCallTargetRegistry* CustomCallTargetRegistry::Global() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void CustomCallTargetRegistry::Register(const std::string& symbol,
|
void CustomCallTargetRegistry::Register(const std::string& symbol,
|
||||||
void* address) {
|
void* address,
|
||||||
|
const std::string& platform) {
|
||||||
std::lock_guard<std::mutex> lock(mu_);
|
std::lock_guard<std::mutex> lock(mu_);
|
||||||
registered_symbols_[symbol] = address;
|
registered_symbols_[std::make_pair(symbol, platform)] = address;
|
||||||
}
|
}
|
||||||
|
|
||||||
void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const {
|
void* CustomCallTargetRegistry::Lookup(const std::string& symbol,
|
||||||
|
const std::string& platform) const {
|
||||||
std::lock_guard<std::mutex> lock(mu_);
|
std::lock_guard<std::mutex> lock(mu_);
|
||||||
auto it = registered_symbols_.find(symbol);
|
auto it = registered_symbols_.find(std::make_pair(symbol, platform));
|
||||||
return it == registered_symbols_.end() ? nullptr : it->second;
|
return it == registered_symbols_.end() ? nullptr : it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -19,19 +19,20 @@ limitations under the License.
|
|||||||
// For this reason, we avoid relying on TensorFlow and instead only use the
|
// For this reason, we avoid relying on TensorFlow and instead only use the
|
||||||
// standard C++ library.
|
// standard C++ library.
|
||||||
|
|
||||||
|
#include <map>
|
||||||
#include <mutex> // NOLINT
|
#include <mutex> // NOLINT
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unordered_map>
|
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
// The CPU JIT compiler uses this registry to resolve symbolic CustomCall
|
// XLA JIT compilers use this registry to resolve symbolic CustomCall targets;
|
||||||
// targets; so when using the CPU JIT, CustomCall targets need to be registered
|
// so when using XLA as a JIT, CustomCall targets need to be registered here
|
||||||
// here with the symbol name used in the CustomCall.
|
// with the symbol name used in the CustomCall.
|
||||||
//
|
//
|
||||||
// The XLA AOT compiler links using a standard offline linker; so when compiling
|
// The XLA:CPU ahead-of-time (AOT) compiler links using a standard offline
|
||||||
// in AOT mode, you *also* need to make sure the name of the callee (presumably
|
// linker; so when compiling in CPU AOT mode, you *also* need to make sure the
|
||||||
// implemented in C++) matches up with the symbolic name used in the CustomCall.
|
// name of the callee (presumably implemented in C++) matches up with the
|
||||||
|
// symbolic name used in the CustomCall.
|
||||||
//
|
//
|
||||||
// We maintain the registry in both the JIT and the AOT cases for simplicity,
|
// We maintain the registry in both the JIT and the AOT cases for simplicity,
|
||||||
// but we only use it when running in JIT mode.
|
// but we only use it when running in JIT mode.
|
||||||
@ -39,32 +40,42 @@ class CustomCallTargetRegistry {
|
|||||||
public:
|
public:
|
||||||
static CustomCallTargetRegistry* Global();
|
static CustomCallTargetRegistry* Global();
|
||||||
|
|
||||||
void Register(const std::string& symbol, void* address);
|
void Register(const std::string& symbol, void* address,
|
||||||
void* Lookup(const std::string& symbol) const;
|
const std::string& platform);
|
||||||
|
void* Lookup(const std::string& symbol, const std::string& platform) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unordered_map<std::string, void*> registered_symbols_;
|
// Maps the pair (symbol, platform) to a C function implementing a custom call
|
||||||
|
// named `symbol` for StreamExecutor platform `platform`.
|
||||||
|
//
|
||||||
|
// Different platforms have different ABIs. TODO(jlebar): Describe them!
|
||||||
|
//
|
||||||
|
// (We std::map rather than std::unordered_map because the STL doesn't provide
|
||||||
|
// a default hasher for pair<string, string>, and we want to avoid pulling in
|
||||||
|
// dependencies that might define this.)
|
||||||
|
std::map<std::pair<std::string, std::string>, void*> registered_symbols_;
|
||||||
mutable std::mutex mu_;
|
mutable std::mutex mu_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class RegisterCustomCallTarget {
|
class RegisterCustomCallTarget {
|
||||||
public:
|
public:
|
||||||
explicit RegisterCustomCallTarget(const std::string& name, void* address) {
|
explicit RegisterCustomCallTarget(const std::string& name, void* address,
|
||||||
CustomCallTargetRegistry::Global()->Register(name, address);
|
const std::string& platform) {
|
||||||
|
CustomCallTargetRegistry::Global()->Register(name, address, platform);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
|
#define XLA_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
|
||||||
|
|
||||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
|
#define XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
|
||||||
counter) \
|
platform, counter) \
|
||||||
static ::xla::RegisterCustomCallTarget XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT( \
|
static ::xla::RegisterCustomCallTarget XLA_REGISTER_CUSTOM_CALL_CONCAT( \
|
||||||
custom_call_target_register, counter)(symbol, \
|
custom_call_target_register, counter)( \
|
||||||
reinterpret_cast<void*>(address))
|
symbol, reinterpret_cast<void*>(address), platform)
|
||||||
|
|
||||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
|
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
|
||||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
|
XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, "Host", \
|
||||||
__COUNTER__)
|
__COUNTER__)
|
||||||
|
|
||||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(function) \
|
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(function) \
|
||||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)
|
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)
|
||||||
|
@ -53,7 +53,7 @@ StatusOr<Literal> HandleEvaluatorCustomCall(
|
|||||||
HloInstruction* custom_call, absl::Span<const Literal*> operands) {
|
HloInstruction* custom_call, absl::Span<const Literal*> operands) {
|
||||||
// Find the target C function in the global registry.
|
// Find the target C function in the global registry.
|
||||||
auto* registry = CustomCallTargetRegistry::Global();
|
auto* registry = CustomCallTargetRegistry::Global();
|
||||||
void* target_fn = registry->Lookup(custom_call->custom_call_target());
|
void* target_fn = registry->Lookup(custom_call->custom_call_target(), "Host");
|
||||||
if (!target_fn) {
|
if (!target_fn) {
|
||||||
return NotFound("Custom call target '%s' was not registered",
|
return NotFound("Custom call target '%s' was not registered",
|
||||||
custom_call->custom_call_target());
|
custom_call->custom_call_target());
|
||||||
|
Loading…
x
Reference in New Issue
Block a user