[XLA:CPU] Prefix XLA:CPU custom-call macros with "XLA_CPU_".
In preparation for adding XLA:GPU custom-call support. It's not possible to write a single C++ function that implements a custom-call for both CPU and GPU; for one thing, the GPU version needs to take a CUDA stream. So it wouldn't make sense to use one macro for both CPU and GPU.
PiperOrigin-RevId: 247489592
This commit is contained in:
parent
259fe4f80a
commit
c182167d3f
@ -46,4 +46,4 @@ extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) {
|
||||
tensorflow::argmax_float_1d_xla_impl(out, data);
|
||||
}
|
||||
|
||||
REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);
|
||||
|
@ -51,4 +51,4 @@ extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) {
|
||||
tensorflow::argmax_float_2d_xla_impl(out, data);
|
||||
}
|
||||
|
||||
REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);
|
||||
|
@ -55,18 +55,21 @@ class RegisterCustomCallTarget {
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
|
||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
|
||||
|
||||
#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, counter) \
|
||||
static ::xla::cpu::RegisterCustomCallTarget REGISTER_CUSTOM_CALL_CONCAT( \
|
||||
custom_call_target_register, counter)(symbol, \
|
||||
reinterpret_cast<void*>(address))
|
||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
|
||||
counter) \
|
||||
static ::xla::cpu::RegisterCustomCallTarget \
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT(custom_call_target_register, \
|
||||
counter)( \
|
||||
symbol, reinterpret_cast<void*>(address))
|
||||
|
||||
#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
|
||||
REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, __COUNTER__)
|
||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
|
||||
__COUNTER__)
|
||||
|
||||
#define REGISTER_CUSTOM_CALL_TARGET(function) \
|
||||
REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)
|
||||
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(function) \
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)
|
||||
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
@ -155,7 +155,8 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
|
||||
LOG(ERROR)
|
||||
<< "Unable to resolve runtime symbol: `" << name
|
||||
<< "'. Hint: if the symbol a custom call target, make sure you've "
|
||||
"registered it with the JIT using REGISTER_CUSTOM_CALL_TARGET.";
|
||||
"registered it with the JIT using "
|
||||
"XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.";
|
||||
return nullptr;
|
||||
}
|
||||
llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr),
|
||||
|
@ -29,7 +29,7 @@ class AliasAnalysisTest : public CpuCodegenTest {};
|
||||
|
||||
void FakeCustomCallTarget(float* out, float** in) {}
|
||||
|
||||
REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget);
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget);
|
||||
|
||||
TEST_F(AliasAnalysisTest, EmbeddedComputationParamsMayAliasTemps) {
|
||||
const char* hlo_string = R"(
|
||||
|
@ -64,10 +64,10 @@ void F32TupleSwap(float** out, float** in) {
|
||||
|
||||
} // namespace
|
||||
|
||||
REGISTER_CUSTOM_CALL_TARGET(R0F32Add2);
|
||||
REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum);
|
||||
REGISTER_CUSTOM_CALL_TARGET(Add1ToValues);
|
||||
REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap);
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R0F32Add2);
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum);
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(Add1ToValues);
|
||||
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap);
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
Loading…
Reference in New Issue
Block a user