[XLA] Add platform name to custom-call target registry.

Remove the assumption that all entries in the registry are for the CPU
platform.

PiperOrigin-RevId: 247542177
This commit is contained in:
Justin Lebar 2019-05-09 19:55:49 -07:00 committed by TensorFlower Gardener
parent 92cec71856
commit 127471f213
5 changed files with 64 additions and 46 deletions

View File

@ -101,8 +101,8 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name,
"Argument to RegisterCpuCustomCallTargetRegistry was not a " "Argument to RegisterCpuCustomCallTargetRegistry was not a "
"xla._CPU_CUSTOM_CALL_TARGET capsule."); "xla._CPU_CUSTOM_CALL_TARGET capsule.");
} }
CustomCallTargetRegistry::Global()->Register(fn_name, CustomCallTargetRegistry::Global()->Register(
static_cast<void*>(capsule)); fn_name, static_cast<void*>(capsule), "Host");
return Status::OK(); return Status::OK();
} }

View File

@ -147,9 +147,10 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
// On Mac OS X, 'name' may have a leading underscore prefix, even though the // On Mac OS X, 'name' may have a leading underscore prefix, even though the
// registered name may not. // registered name may not.
std::string stripped_name(name.begin() + 1, name.end()); std::string stripped_name(name.begin() + 1, name.end());
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name); func_addr =
xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host");
} else { } else {
func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name); func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host");
} }
if (func_addr == nullptr) { if (func_addr == nullptr) {
@ -219,7 +220,7 @@ bool RegisterKnownJITSymbols() {
auto* function_address = \ auto* function_address = \
reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \ reinterpret_cast<void*>(__xla_cpu_runtime_##base_name); \
registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \
function_address); \ function_address, "Host"); \
CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \ CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \
"__xla_cpu_runtime_" #base_name); \ "__xla_cpu_runtime_" #base_name); \
} while (false) } while (false)
@ -250,8 +251,10 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(TracingStart); REGISTER_CPU_RUNTIME_SYMBOL(TracingStart);
REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd); REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd);
registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee)); registry->Register("__gnu_f2h_ieee", reinterpret_cast<void*>(__gnu_f2h_ieee),
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee)); "Host");
registry->Register("__gnu_h2f_ieee", reinterpret_cast<void*>(__gnu_h2f_ieee),
"Host");
#undef REGISTER_CPU_RUNTIME_SYMBOL #undef REGISTER_CPU_RUNTIME_SYMBOL
@ -259,11 +262,12 @@ bool RegisterKnownJITSymbols() {
// Unfortunately the double versions are overloaded on some systems, e.g. // Unfortunately the double versions are overloaded on some systems, e.g.
// Mac so we need an explicit cast. This requires passing the function signature // Mac so we need an explicit cast. This requires passing the function signature
// for that case. // for that case.
#define REGISTER_LIBM_SYMBOL(name, double_sig) \ #define REGISTER_LIBM_SYMBOL(name, double_sig) \
do { \ do { \
registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \ registry->Register(#name "f", reinterpret_cast<void*>(name##f), "Host"); \
registry->Register( \ registry->Register(#name, \
#name, reinterpret_cast<void*>(static_cast<double_sig>(name))); \ reinterpret_cast<void*>(static_cast<double_sig>(name)), \
"Host"); \
} while (false) } while (false)
REGISTER_LIBM_SYMBOL(acos, double (*)(double)); REGISTER_LIBM_SYMBOL(acos, double (*)(double));
@ -321,8 +325,9 @@ bool RegisterKnownJITSymbols() {
#ifdef __APPLE__ #ifdef __APPLE__
REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*));
registry->Register("__sincosf_stret", registry->Register("__sincosf_stret",
reinterpret_cast<void*>(__sincosf_stret)); reinterpret_cast<void*>(__sincosf_stret), "Host");
registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret)); registry->Register("__sincos_stret", reinterpret_cast<void*>(__sincos_stret),
"Host");
#else #else
REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
#endif #endif
@ -335,19 +340,19 @@ bool RegisterKnownJITSymbols() {
#undef REGISTER_LIBM_SYMBOL #undef REGISTER_LIBM_SYMBOL
registry->Register("memcpy", reinterpret_cast<void*>(memcpy)); registry->Register("memcpy", reinterpret_cast<void*>(memcpy), "Host");
registry->Register("memmove", reinterpret_cast<void*>(memmove)); registry->Register("memmove", reinterpret_cast<void*>(memmove), "Host");
registry->Register("memset", reinterpret_cast<void*>(memset)); registry->Register("memset", reinterpret_cast<void*>(memset), "Host");
#ifdef __APPLE__ #ifdef __APPLE__
registry->Register("__bzero", reinterpret_cast<void*>(bzero)); registry->Register("__bzero", reinterpret_cast<void*>(bzero), "Host");
registry->Register("memset_pattern16", registry->Register("memset_pattern16",
reinterpret_cast<void*>(memset_pattern16)); reinterpret_cast<void*>(memset_pattern16), "Host");
#endif #endif
#ifdef MEMORY_SANITIZER #ifdef MEMORY_SANITIZER
registry->Register("__msan_unpoison", registry->Register("__msan_unpoison",
reinterpret_cast<void*>(__msan_unpoison)); reinterpret_cast<void*>(__msan_unpoison), "Host");
#endif #endif
return true; return true;

View File

@ -23,14 +23,16 @@ CustomCallTargetRegistry* CustomCallTargetRegistry::Global() {
} }
void CustomCallTargetRegistry::Register(const std::string& symbol, void CustomCallTargetRegistry::Register(const std::string& symbol,
void* address) { void* address,
const std::string& platform) {
std::lock_guard<std::mutex> lock(mu_); std::lock_guard<std::mutex> lock(mu_);
registered_symbols_[symbol] = address; registered_symbols_[std::make_pair(symbol, platform)] = address;
} }
void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const { void* CustomCallTargetRegistry::Lookup(const std::string& symbol,
const std::string& platform) const {
std::lock_guard<std::mutex> lock(mu_); std::lock_guard<std::mutex> lock(mu_);
auto it = registered_symbols_.find(symbol); auto it = registered_symbols_.find(std::make_pair(symbol, platform));
return it == registered_symbols_.end() ? nullptr : it->second; return it == registered_symbols_.end() ? nullptr : it->second;
} }

View File

@ -19,19 +19,20 @@ limitations under the License.
// For this reason, we avoid relying on TensorFlow and instead only use the // For this reason, we avoid relying on TensorFlow and instead only use the
// standard C++ library. // standard C++ library.
#include <map>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <string> #include <string>
#include <unordered_map>
namespace xla { namespace xla {
// The CPU JIT compiler uses this registry to resolve symbolic CustomCall // XLA JIT compilers use this registry to resolve symbolic CustomCall targets;
// targets; so when using the CPU JIT, CustomCall targets need to be registered // so when using XLA as a JIT, CustomCall targets need to be registered here
// here with the symbol name used in the CustomCall. // with the symbol name used in the CustomCall.
// //
// The XLA AOT compiler links using a standard offline linker; so when compiling // The XLA:CPU ahead-of-time (AOT) compiler links using a standard offline
// in AOT mode, you *also* need to make sure the name of the callee (presumably // linker; so when compiling in CPU AOT mode, you *also* need to make sure the
// implemented in C++) matches up with the symbolic name used in the CustomCall. // name of the callee (presumably implemented in C++) matches up with the
// symbolic name used in the CustomCall.
// //
// We maintain the registry in both the JIT and the AOT cases for simplicity, // We maintain the registry in both the JIT and the AOT cases for simplicity,
// but we only use it when running in JIT mode. // but we only use it when running in JIT mode.
@ -39,32 +40,42 @@ class CustomCallTargetRegistry {
public: public:
static CustomCallTargetRegistry* Global(); static CustomCallTargetRegistry* Global();
void Register(const std::string& symbol, void* address); void Register(const std::string& symbol, void* address,
void* Lookup(const std::string& symbol) const; const std::string& platform);
void* Lookup(const std::string& symbol, const std::string& platform) const;
private: private:
std::unordered_map<std::string, void*> registered_symbols_; // Maps the pair (symbol, platform) to a C function implementing a custom call
// named `symbol` for StreamExecutor platform `platform`.
//
// Different platforms have different ABIs. TODO(jlebar): Describe them!
//
// (We use std::map rather than std::unordered_map because the STL doesn't provide
// a default hasher for pair<string, string>, and we want to avoid pulling in
// dependencies that might define this.)
std::map<std::pair<std::string, std::string>, void*> registered_symbols_;
mutable std::mutex mu_; mutable std::mutex mu_;
}; };
class RegisterCustomCallTarget { class RegisterCustomCallTarget {
public: public:
explicit RegisterCustomCallTarget(const std::string& name, void* address) { explicit RegisterCustomCallTarget(const std::string& name, void* address,
CustomCallTargetRegistry::Global()->Register(name, address); const std::string& platform) {
CustomCallTargetRegistry::Global()->Register(name, address, platform);
} }
}; };
#define XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b #define XLA_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \ #define XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
counter) \ platform, counter) \
static ::xla::RegisterCustomCallTarget XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT( \ static ::xla::RegisterCustomCallTarget XLA_REGISTER_CUSTOM_CALL_CONCAT( \
custom_call_target_register, counter)(symbol, \ custom_call_target_register, counter)( \
reinterpret_cast<void*>(address)) symbol, reinterpret_cast<void*>(address), platform)
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \ #define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, "Host", \
__COUNTER__) __COUNTER__)
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(function) \ #define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(function) \
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function) XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)

View File

@ -53,7 +53,7 @@ StatusOr<Literal> HandleEvaluatorCustomCall(
HloInstruction* custom_call, absl::Span<const Literal*> operands) { HloInstruction* custom_call, absl::Span<const Literal*> operands) {
// Find the target C function in the global registry. // Find the target C function in the global registry.
auto* registry = CustomCallTargetRegistry::Global(); auto* registry = CustomCallTargetRegistry::Global();
void* target_fn = registry->Lookup(custom_call->custom_call_target()); void* target_fn = registry->Lookup(custom_call->custom_call_target(), "Host");
if (!target_fn) { if (!target_fn) {
return NotFound("Custom call target '%s' was not registered", return NotFound("Custom call target '%s' was not registered",
custom_call->custom_call_target()); custom_call->custom_call_target());