[XLA:CPU] Prefix XLA:CPU custom-call macros with "XLA_CPU_".

In preparation for adding XLA:GPU custom-call support.

It's not possible to write a C++ function that implements a custom-call for
both CPU and GPU; for one thing, the GPU version needs to take a CUDA stream.
So it wouldn't make sense to use one macro for both CPU and GPU.

PiperOrigin-RevId: 247489592
This commit is contained in:
Justin Lebar 2019-05-09 13:51:36 -07:00 committed by TensorFlower Gardener
parent 259fe4f80a
commit c182167d3f
6 changed files with 21 additions and 17 deletions

View File

@ -46,4 +46,4 @@ extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) {
tensorflow::argmax_float_1d_xla_impl(out, data);
}
REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl);

View File

@ -51,4 +51,4 @@ extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) {
tensorflow::argmax_float_2d_xla_impl(out, data);
}
REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl);

View File

@ -55,18 +55,21 @@ class RegisterCustomCallTarget {
}
};
#define REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
#define XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b
#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, counter) \
static ::xla::cpu::RegisterCustomCallTarget REGISTER_CUSTOM_CALL_CONCAT( \
custom_call_target_register, counter)(symbol, \
reinterpret_cast<void*>(address))
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
counter) \
static ::xla::cpu::RegisterCustomCallTarget \
XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT(custom_call_target_register, \
counter)( \
symbol, reinterpret_cast<void*>(address))
#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, __COUNTER__)
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \
__COUNTER__)
#define REGISTER_CUSTOM_CALL_TARGET(function) \
REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)
#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(function) \
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function)
} // namespace cpu
} // namespace xla

View File

@ -155,7 +155,8 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) {
LOG(ERROR)
<< "Unable to resolve runtime symbol: `" << name
<< "'. Hint: if the symbol a custom call target, make sure you've "
"registered it with the JIT using REGISTER_CUSTOM_CALL_TARGET.";
"registered it with the JIT using "
"XLA_CPU_REGISTER_CUSTOM_CALL_TARGET.";
return nullptr;
}
llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast<uint64_t>(func_addr),

View File

@ -29,7 +29,7 @@ class AliasAnalysisTest : public CpuCodegenTest {};
void FakeCustomCallTarget(float* out, float** in) {}
REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget);
TEST_F(AliasAnalysisTest, EmbeddedComputationParamsMayAliasTemps) {
const char* hlo_string = R"(

View File

@ -64,10 +64,10 @@ void F32TupleSwap(float** out, float** in) {
} // namespace
REGISTER_CUSTOM_CALL_TARGET(R0F32Add2);
REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum);
REGISTER_CUSTOM_CALL_TARGET(Add1ToValues);
REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R0F32Add2);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(Add1ToValues);
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap);
namespace xla {
namespace {