Enable XLA custom_call_thunk on the ROCm platform

Based on the contribution by @inailuig in the following commit:

44d3a233c6
Deven Desai 2020-12-07 18:53:21 +00:00
parent 4a227e9fc5
commit 06a5f454b0
2 changed files with 12 additions and 1 deletion

tensorflow/compiler/xla/service/gpu/BUILD

@@ -115,7 +115,7 @@ cc_library(
tf_cc_test(
name = "custom_call_test",
-srcs = if_cuda_is_configured(["custom_call_test.cc"]),
+srcs = if_cuda_or_rocm(["custom_call_test.cc"]),
tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
@@ -720,6 +720,8 @@ cc_library(
] + if_cuda_is_configured([
"cholesky_thunk.cc",
"custom_call_thunk.cc",
+]) + if_rocm_is_configured([
+"custom_call_thunk.cc",
]),
hdrs = [
"collective_permute_thunk.h",
@@ -744,6 +746,8 @@ cc_library(
] + if_cuda_is_configured([
"cholesky_thunk.h",
"custom_call_thunk.h",
+]) + if_rocm_is_configured([
+"custom_call_thunk.h",
]),
deps = [
":backend_configs_cc",

tensorflow/compiler/xla/service/gpu/thunk_emitter.cc

@@ -33,6 +33,10 @@ limitations under the License.
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
+#endif
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
+    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
#endif
@@ -374,7 +378,10 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
return Status::OK();
}
+#endif
+#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
+    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
if (void* call_target = CustomCallTargetRegistry::Global()->Lookup(
custom_call->custom_call_target(), std::string(platform_name()))) {
auto get_slices_for_instr = [&](const HloInstruction* instr) {
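
Note (illustrative, not part of this commit): the Lookup() call shown above only succeeds if a custom-call target was registered under the platform name returned by platform_name(), which is "ROCM" on AMD GPUs rather than "CUDA". Below is a minimal sketch of such a registration, assuming the XLA_REGISTER_CUSTOM_CALL_TARGET macro from tensorflow/compiler/xla/service/custom_call_target_registry.h and the GPU custom-call calling convention of this era (stream, buffers, opaque, opaque_len); MyCustomCallTarget and the GpuStreamHandle alias are hypothetical names introduced only for the example.

// Illustration only: registering a GPU custom-call target so that
// CustomCallTargetRegistry::Global()->Lookup(name, platform) can resolve it
// on ROCm as well as CUDA.
#include <stddef.h>

#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"

#if defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM
#include <hip/hip_runtime.h>           // provides hipStream_t
using GpuStreamHandle = hipStream_t;   // hypothetical alias for this sketch
#else
#include <cuda.h>                      // provides CUstream
using GpuStreamHandle = CUstream;
#endif

// buffers[0] holds the input device pointer and buffers[1] the output;
// `opaque` carries the raw bytes attached to the CustomCall HLO instruction.
void MyCustomCallTarget(GpuStreamHandle stream, void** buffers,
                        const char* opaque, size_t opaque_len) {
  // Enqueue device work on `stream` here (kernel launch omitted in this sketch).
}

// Register under the stream-executor platform name that platform_name()
// returns: "ROCM" on AMD GPUs, "CUDA" on NVIDIA GPUs.
#if defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM
XLA_REGISTER_CUSTOM_CALL_TARGET(MyCustomCallTarget, "ROCM");
#else
XLA_REGISTER_CUSTOM_CALL_TARGET(MyCustomCallTarget, "CUDA");
#endif

For the lookup in HandleCustomCall to find this target, the custom_call_target string on the CustomCall HLO instruction must match the registered symbol name ("MyCustomCallTarget" in this sketch).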