Enable XLA custom_call_thunk on the ROCm platform
Based on @inailuig's contribution of the same change in commit 44d3a233c6.
This commit is contained in:
parent
4a227e9fc5
commit
06a5f454b0
@ -115,7 +115,7 @@ cc_library(
|
||||
|
||||
tf_cc_test(
|
||||
name = "custom_call_test",
|
||||
srcs = if_cuda_is_configured(["custom_call_test.cc"]),
|
||||
srcs = if_cuda_or_rocm(["custom_call_test.cc"]),
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
@ -720,6 +720,8 @@ cc_library(
|
||||
] + if_cuda_is_configured([
|
||||
"cholesky_thunk.cc",
|
||||
"custom_call_thunk.cc",
|
||||
]) + if_rocm_is_configured([
|
||||
"custom_call_thunk.cc",
|
||||
]),
|
||||
hdrs = [
|
||||
"collective_permute_thunk.h",
|
||||
@ -744,6 +746,8 @@ cc_library(
|
||||
] + if_cuda_is_configured([
|
||||
"cholesky_thunk.h",
|
||||
"custom_call_thunk.h",
|
||||
]) + if_rocm_is_configured([
|
||||
"custom_call_thunk.h",
|
||||
]),
|
||||
deps = [
|
||||
":backend_configs_cc",
|
||||
|
@ -33,6 +33,10 @@ limitations under the License.
|
||||
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
|
||||
#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
|
||||
#endif
|
||||
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
|
||||
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
|
||||
#include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
|
||||
#endif
|
||||
|
||||
@ -374,7 +378,10 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
|
||||
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
|
||||
if (void* call_target = CustomCallTargetRegistry::Global()->Lookup(
|
||||
custom_call->custom_call_target(), std::string(platform_name()))) {
|
||||
auto get_slices_for_instr = [&](const HloInstruction* instr) {
|
||||
|
Loading…
Reference in New Issue
Block a user