Enable XLA custom_call_thunk on the ROCm platform.
Based on @inailuig's contribution of the same change in the following commit:
44d3a233c6
This commit is contained in:
parent
4a227e9fc5
commit
06a5f454b0
@ -115,7 +115,7 @@ cc_library(
|
|||||||
|
|
||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "custom_call_test",
|
name = "custom_call_test",
|
||||||
srcs = if_cuda_is_configured(["custom_call_test.cc"]),
|
srcs = if_cuda_or_rocm(["custom_call_test.cc"]),
|
||||||
tags = tf_cuda_tests_tags(),
|
tags = tf_cuda_tests_tags(),
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||||
@ -720,6 +720,8 @@ cc_library(
|
|||||||
] + if_cuda_is_configured([
|
] + if_cuda_is_configured([
|
||||||
"cholesky_thunk.cc",
|
"cholesky_thunk.cc",
|
||||||
"custom_call_thunk.cc",
|
"custom_call_thunk.cc",
|
||||||
|
]) + if_rocm_is_configured([
|
||||||
|
"custom_call_thunk.cc",
|
||||||
]),
|
]),
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"collective_permute_thunk.h",
|
"collective_permute_thunk.h",
|
||||||
@ -744,6 +746,8 @@ cc_library(
|
|||||||
] + if_cuda_is_configured([
|
] + if_cuda_is_configured([
|
||||||
"cholesky_thunk.h",
|
"cholesky_thunk.h",
|
||||||
"custom_call_thunk.h",
|
"custom_call_thunk.h",
|
||||||
|
]) + if_rocm_is_configured([
|
||||||
|
"custom_call_thunk.h",
|
||||||
]),
|
]),
|
||||||
deps = [
|
deps = [
|
||||||
":backend_configs_cc",
|
":backend_configs_cc",
|
||||||
|
@ -33,6 +33,10 @@ limitations under the License.
|
|||||||
|
|
||||||
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
|
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA)
|
||||||
#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
|
#include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
|
||||||
|
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
|
||||||
#include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
|
#include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
@ -374,7 +378,10 @@ Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
|||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
|
||||||
|
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
|
||||||
if (void* call_target = CustomCallTargetRegistry::Global()->Lookup(
|
if (void* call_target = CustomCallTargetRegistry::Global()->Lookup(
|
||||||
custom_call->custom_call_target(), std::string(platform_name()))) {
|
custom_call->custom_call_target(), std::string(platform_name()))) {
|
||||||
auto get_slices_for_instr = [&](const HloInstruction* instr) {
|
auto get_slices_for_instr = [&](const HloInstruction* instr) {
|
||||||
|
Loading…
Reference in New Issue
Block a user