diff --git a/tensorflow/compiler/xla/python/local_client.cc b/tensorflow/compiler/xla/python/local_client.cc index 97ce1ca609a..40d737aaffd 100644 --- a/tensorflow/compiler/xla/python/local_client.cc +++ b/tensorflow/compiler/xla/python/local_client.cc @@ -101,8 +101,8 @@ Status RegisterCpuCustomCallTarget(const std::string& fn_name, "Argument to RegisterCpuCustomCallTargetRegistry was not a " "xla._CPU_CUSTOM_CALL_TARGET capsule."); } - CustomCallTargetRegistry::Global()->Register(fn_name, - static_cast(capsule)); + CustomCallTargetRegistry::Global()->Register( + fn_name, static_cast(capsule), "Host"); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 0ca5a09dcb6..bf55e9e22cf 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -147,9 +147,10 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { // On Mac OS X, 'name' may have a leading underscore prefix, even though the // registered name may not. std::string stripped_name(name.begin() + 1, name.end()); - func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name); + func_addr = + xla::CustomCallTargetRegistry::Global()->Lookup(stripped_name, "Host"); } else { - func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name); + func_addr = xla::CustomCallTargetRegistry::Global()->Lookup(name, "Host"); } if (func_addr == nullptr) { @@ -219,7 +220,7 @@ bool RegisterKnownJITSymbols() { auto* function_address = \ reinterpret_cast(__xla_cpu_runtime_##base_name); \ registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ - function_address); \ + function_address, "Host"); \ CHECK_EQ(absl::string_view(xla::cpu::runtime::k##base_name##SymbolName), \ "__xla_cpu_runtime_" #base_name); \ } while (false) @@ -250,8 +251,10 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(TracingStart); REGISTER_CPU_RUNTIME_SYMBOL(TracingEnd); - registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); - registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); + registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee), + "Host"); + registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee), + "Host"); #undef REGISTER_CPU_RUNTIME_SYMBOL @@ -259,11 +262,12 @@ bool RegisterKnownJITSymbols() { // Unfortunately the double versions are overloaded on some systems, e.g. // Mac so we need an explicit cast. This requires passing the function signature // for that case. -#define REGISTER_LIBM_SYMBOL(name, double_sig) \ - do { \ - registry->Register(#name "f", reinterpret_cast(name##f)); \ - registry->Register( \ - #name, reinterpret_cast(static_cast(name))); \ +#define REGISTER_LIBM_SYMBOL(name, double_sig) \ + do { \ + registry->Register(#name "f", reinterpret_cast(name##f), "Host"); \ + registry->Register(#name, \ + reinterpret_cast(static_cast(name)), \ + "Host"); \ } while (false) REGISTER_LIBM_SYMBOL(acos, double (*)(double)); @@ -321,8 +325,9 @@ bool RegisterKnownJITSymbols() { #ifdef __APPLE__ REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); registry->Register("__sincosf_stret", - reinterpret_cast(__sincosf_stret)); - registry->Register("__sincos_stret", reinterpret_cast(__sincos_stret)); + reinterpret_cast(__sincosf_stret), "Host"); + registry->Register("__sincos_stret", reinterpret_cast(__sincos_stret), + "Host"); #else REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); #endif @@ -335,19 +340,19 @@ bool RegisterKnownJITSymbols() { #undef REGISTER_LIBM_SYMBOL - registry->Register("memcpy", reinterpret_cast(memcpy)); - registry->Register("memmove", reinterpret_cast(memmove)); - registry->Register("memset", reinterpret_cast(memset)); + registry->Register("memcpy", reinterpret_cast(memcpy), "Host"); + registry->Register("memmove", reinterpret_cast(memmove), "Host"); + registry->Register("memset", reinterpret_cast(memset), "Host"); #ifdef __APPLE__ - registry->Register("__bzero", reinterpret_cast(bzero)); + registry->Register("__bzero", reinterpret_cast(bzero), "Host"); registry->Register("memset_pattern16", - reinterpret_cast(memset_pattern16)); + reinterpret_cast(memset_pattern16), "Host"); #endif #ifdef MEMORY_SANITIZER registry->Register("__msan_unpoison", - reinterpret_cast(__msan_unpoison)); + reinterpret_cast(__msan_unpoison), "Host"); #endif return true; diff --git a/tensorflow/compiler/xla/service/custom_call_target_registry.cc b/tensorflow/compiler/xla/service/custom_call_target_registry.cc index 49875d3c8d1..e6a70211f25 100644 --- a/tensorflow/compiler/xla/service/custom_call_target_registry.cc +++ b/tensorflow/compiler/xla/service/custom_call_target_registry.cc @@ -23,14 +23,16 @@ CustomCallTargetRegistry* CustomCallTargetRegistry::Global() { } void CustomCallTargetRegistry::Register(const std::string& symbol, - void* address) { + void* address, + const std::string& platform) { std::lock_guard lock(mu_); - registered_symbols_[symbol] = address; + registered_symbols_[std::make_pair(symbol, platform)] = address; } -void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const { +void* CustomCallTargetRegistry::Lookup(const std::string& symbol, + const std::string& platform) const { std::lock_guard lock(mu_); - auto it = registered_symbols_.find(symbol); + auto it = registered_symbols_.find(std::make_pair(symbol, platform)); return it == registered_symbols_.end() ? nullptr : it->second; } diff --git a/tensorflow/compiler/xla/service/custom_call_target_registry.h b/tensorflow/compiler/xla/service/custom_call_target_registry.h index 7832e06ece0..b1c9494afed 100644 --- a/tensorflow/compiler/xla/service/custom_call_target_registry.h +++ b/tensorflow/compiler/xla/service/custom_call_target_registry.h @@ -19,19 +19,20 @@ limitations under the License. // For this reason, we avoid relying on TensorFlow and instead only use the // standard C++ library. +#include #include // NOLINT #include -#include namespace xla { -// The CPU JIT compiler uses this registry to resolve symbolic CustomCall -// targets; so when using the CPU JIT, CustomCall targets need to be registered -// here with the symbol name used in the CustomCall. +// XLA JIT compilers use this registry to resolve symbolic CustomCall targets; +// so when using XLA as a JIT, CustomCall targets need to be registered here +// with the symbol name used in the CustomCall. // -// The XLA AOT compiler links using a standard offline linker; so when compiling -// in AOT mode, you *also* need to make sure the name of the callee (presumably -// implemented in C++) matches up with the symbolic name used in the CustomCall. +// The XLA:CPU ahead-of-time (AOT) compiler links using a standard offline +// linker; so when compiling in CPU AOT mode, you *also* need to make sure the +// name of the callee (presumably implemented in C++) matches up with the +// symbolic name used in the CustomCall. // // We maintain the registry in both the JIT and the AOT cases for simplicity, // but we only use it when running in JIT mode. @@ -39,32 +40,42 @@ class CustomCallTargetRegistry { public: static CustomCallTargetRegistry* Global(); - void Register(const std::string& symbol, void* address); - void* Lookup(const std::string& symbol) const; + void Register(const std::string& symbol, void* address, + const std::string& platform); + void* Lookup(const std::string& symbol, const std::string& platform) const; private: - std::unordered_map registered_symbols_; + // Maps the pair (symbol, platform) to a C function implementing a custom call + // named `symbol` for StreamExecutor platform `platform`. + // + // Different platforms have different ABIs. TODO(jlebar): Describe them! + // + // (We std::map rather than std::unordered_map because the STL doesn't provide + // a default hasher for pair, and we want to avoid pulling in + // dependencies that might define this.) + std::map, void*> registered_symbols_; mutable std::mutex mu_; }; class RegisterCustomCallTarget { public: - explicit RegisterCustomCallTarget(const std::string& name, void* address) { - CustomCallTargetRegistry::Global()->Register(name, address); + explicit RegisterCustomCallTarget(const std::string& name, void* address, + const std::string& platform) { + CustomCallTargetRegistry::Global()->Register(name, address, platform); } }; -#define XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b +#define XLA_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b -#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \ - counter) \ - static ::xla::RegisterCustomCallTarget XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT( \ - custom_call_target_register, counter)(symbol, \ - reinterpret_cast(address)) +#define XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \ + platform, counter) \ + static ::xla::RegisterCustomCallTarget XLA_REGISTER_CUSTOM_CALL_CONCAT( \ + custom_call_target_register, counter)( \ + symbol, reinterpret_cast(address), platform) -#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \ - XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \ - __COUNTER__) +#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \ + XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, "Host", \ + __COUNTER__) #define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(function) \ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function) diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 12807cbbe3a..80a3ebccff1 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -53,7 +53,7 @@ StatusOr HandleEvaluatorCustomCall( HloInstruction* custom_call, absl::Span operands) { // Find the target C function in the global registry. auto* registry = CustomCallTargetRegistry::Global(); - void* target_fn = registry->Lookup(custom_call->custom_call_target()); + void* target_fn = registry->Lookup(custom_call->custom_call_target(), "Host"); if (!target_fn) { return NotFound("Custom call target '%s' was not registered", custom_call->custom_call_target());