diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index 39d96e748b3..1d8d480c63d 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -46,4 +46,4 @@ extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) { tensorflow::argmax_float_1d_xla_impl(out, data); } -REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc index 9b83392d8fb..7cae1790510 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc @@ -51,4 +51,4 @@ extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) { tensorflow::argmax_float_2d_xla_impl(out, data); } -REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl); diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h index 664125ecc95..13cbe8967d6 100644 --- a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h +++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h @@ -55,18 +55,21 @@ class RegisterCustomCallTarget { } }; -#define REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b +#define XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b -#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, counter) \ - static ::xla::cpu::RegisterCustomCallTarget REGISTER_CUSTOM_CALL_CONCAT( \ - custom_call_target_register, counter)(symbol, \ - reinterpret_cast(address)) +#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \ + counter) \ + static ::xla::cpu::RegisterCustomCallTarget \ + XLA_CPU_REGISTER_CUSTOM_CALL_CONCAT(custom_call_target_register, \ + counter)( \ + symbol, reinterpret_cast(address)) -#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \ - REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, __COUNTER__) +#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \ + XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, \ + __COUNTER__) -#define REGISTER_CUSTOM_CALL_TARGET(function) \ - REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function) +#define XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(function) \ + XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function) } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 4ac68691e73..26d47a7b6dd 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -155,7 +155,8 @@ llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { LOG(ERROR) << "Unable to resolve runtime symbol: `" << name << "'. Hint: if the symbol a custom call target, make sure you've " - "registered it with the JIT using REGISTER_CUSTOM_CALL_TARGET."; + "registered it with the JIT using " + "XLA_CPU_REGISTER_CUSTOM_CALL_TARGET."; return nullptr; } llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc index db900856993..6eab8cf8444 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis_test.cc @@ -29,7 +29,7 @@ class AliasAnalysisTest : public CpuCodegenTest {}; void FakeCustomCallTarget(float* out, float** in) {} -REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(FakeCustomCallTarget); TEST_F(AliasAnalysisTest, EmbeddedComputationParamsMayAliasTemps) { const char* hlo_string = R"( diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 4687ed61a7d..90dd725acba 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -64,10 +64,10 @@ void F32TupleSwap(float** out, float** in) { } // namespace -REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); -REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); -REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); -REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); +XLA_CPU_REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap); namespace xla { namespace {