From 277e22a01540b76151f5a664c8fd08854663fe27 Mon Sep 17 00:00:00 2001 From: Deven Desai <36858332+deven-amd@users.noreply.github.com> Date: Fri, 8 Jan 2021 03:05:28 -0800 Subject: [PATCH] PR #46222: [ROCm] Updating XLA custom_call_test to enable it for the ROCm platform Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/46222 -------------------------- /cc @chsigg @cheshire @nvining-work Copybara import of the project: -- 823c406a07c9f2644ef82c0407f5f6f3c895428a by Deven Desai : [ROCm] Updating XLA custom_call_test to enable it for the ROCm platform COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/46222 from ROCmSoftwarePlatform:google_upstream_rocm_fix_xla_custom_call_test 823c406a07c9f2644ef82c0407f5f6f3c895428a PiperOrigin-RevId: 350730351 Change-Id: Id64bd074fda2b185e4791c926bd59b944db60a11 --- tensorflow/compiler/xla/service/gpu/BUILD | 6 +- .../xla/service/gpu/custom_call_test.cc | 64 +++++++++++++------ 2 files changed, 47 insertions(+), 23 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 09957450293..82b1e7707c4 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -119,8 +119,6 @@ tf_cc_test( tags = tf_cuda_tests_tags(), deps = [ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep - ] + if_cuda_is_configured([ - "@local_config_cuda//cuda:cuda_headers", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:xla_builder", @@ -129,6 +127,10 @@ tf_cc_test( "//tensorflow/compiler/xla/service:gpu_plugin", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", + ] + if_cuda_is_configured([ + "@local_config_cuda//cuda:cuda_headers", + ]) + if_rocm_is_configured([ + "@local_config_rocm//rocm:rocm_headers", ]), ) diff --git a/tensorflow/compiler/xla/service/gpu/custom_call_test.cc b/tensorflow/compiler/xla/service/gpu/custom_call_test.cc index e28cb662116..38934a09fc3 100644 --- a/tensorflow/compiler/xla/service/gpu/custom_call_test.cc +++ b/tensorflow/compiler/xla/service/gpu/custom_call_test.cc @@ -13,9 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#if GOOGLE_CUDA #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/gpus/cuda/include/driver_types.h" +#define PLATFORM "CUDA" +#elif TENSORFLOW_USE_ROCM +#include "rocm/include/hip/hip_runtime.h" +#define PLATFORM "ROCM" +#endif #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/custom_call_target_registry.h" @@ -23,6 +29,23 @@ limitations under the License. #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/stream_executor/gpu/gpu_types.h" + +#if GOOGLE_CUDA +#define gpuSuccess cudaSuccess +#define gpuMemcpyAsync cudaMemcpyAsync +#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice +#define gpuMemcpy cudaMemcpy +#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost +#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice +#elif TENSORFLOW_USE_ROCM +#define gpuSuccess hipSuccess +#define gpuMemcpyAsync hipMemcpyAsync +#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define gpuMemcpy hipMemcpy +#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost +#define gpuMemcpyHostToDevice hipMemcpyHostToDevice +#endif namespace xla { namespace { @@ -30,11 +53,11 @@ namespace { class CustomCallTest : public ClientLibraryTestBase {}; bool is_invoked_called = false; -void Callback_IsInvoked(CUstream /*stream*/, void** /*buffers*/, +void Callback_IsInvoked(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/, const char* /*opaque*/, size_t /*opaque_len*/) { is_invoked_called = true; } -XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_IsInvoked, "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_IsInvoked, PLATFORM); TEST_F(CustomCallTest, IsInvoked) { XlaBuilder b(TestName()); @@ -53,16 +76,15 @@ TEST_F(CustomCallTest, UnknownTarget) { /*opaque=*/""); ASSERT_FALSE(Execute(&b, {}).ok()); } - -void Callback_Memcpy(CUstream stream, void** buffers, const char* /*opaque*/, - size_t /*opaque_len*/) { +void Callback_Memcpy(se::gpu::GpuStreamHandle stream, void** buffers, + const char* /*opaque*/, size_t /*opaque_len*/) { void* src = buffers[0]; void* dst = buffers[1]; - auto err = cudaMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128, - cudaMemcpyDeviceToDevice, stream); - ASSERT_EQ(err, cudaSuccess); + auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128, + gpuMemcpyDeviceToDevice, stream); + ASSERT_EQ(err, gpuSuccess); } -XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, PLATFORM); TEST_F(CustomCallTest, Memcpy) { XlaBuilder b(TestName()); CustomCall(&b, "Callback_Memcpy", @@ -74,12 +96,12 @@ TEST_F(CustomCallTest, Memcpy) { // Check that opaque handles nulls within the string. std::string& kExpectedOpaque = *new std::string("abc\0def", 7); -void Callback_Opaque(CUstream /*stream*/, void** /*buffers*/, +void Callback_Opaque(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/, const char* opaque, size_t opaque_len) { std::string opaque_str(opaque, opaque_len); ASSERT_EQ(opaque_str, kExpectedOpaque); } -XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Opaque, "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Opaque, PLATFORM); TEST_F(CustomCallTest, Opaque) { XlaBuilder b(TestName()); CustomCall(&b, "Callback_Opaque", /*operands=*/{}, @@ -87,7 +109,7 @@ TEST_F(CustomCallTest, Opaque) { TF_ASSERT_OK(Execute(&b, {}).status()); } -void Callback_SubBuffers(CUstream stream, void** buffers, +void Callback_SubBuffers(se::gpu::GpuStreamHandle stream, void** buffers, const char* /*opaque*/, size_t /*opaque_len*/) { // `buffers` is a flat array containing device pointers to the following. // @@ -103,16 +125,16 @@ void Callback_SubBuffers(CUstream stream, void** buffers, // Set output leaf buffers, copying data from the corresponding same-sized // inputs. - cudaMemcpyAsync(buffers[4], buffers[3], 8 * sizeof(float), - cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(buffers[5], buffers[0], 128 * sizeof(float), - cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(buffers[6], buffers[1], 256 * sizeof(float), - cudaMemcpyDeviceToDevice, stream); - cudaMemcpyAsync(buffers[7], buffers[2], 1024 * sizeof(float), - cudaMemcpyDeviceToDevice, stream); + gpuMemcpyAsync(buffers[4], buffers[3], 8 * sizeof(float), + gpuMemcpyDeviceToDevice, stream); + gpuMemcpyAsync(buffers[5], buffers[0], 128 * sizeof(float), + gpuMemcpyDeviceToDevice, stream); + gpuMemcpyAsync(buffers[6], buffers[1], 256 * sizeof(float), + gpuMemcpyDeviceToDevice, stream); + gpuMemcpyAsync(buffers[7], buffers[2], 1024 * sizeof(float), + gpuMemcpyDeviceToDevice, stream); } -XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, "CUDA"); +XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, PLATFORM); TEST_F(CustomCallTest, SubBuffers) { XlaBuilder b(TestName()); CustomCall(&b, "Callback_SubBuffers", /*operands=*/