PR #46222: [ROCm] Updating XLA custom_call_test to enable it for the ROCm platform

Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/46222

--------------------------

/cc @chsigg @cheshire @nvining-work
Copybara import of the project:

--
823c406a07c9f2644ef82c0407f5f6f3c895428a by Deven Desai <deven.desai.amd@gmail.com>:

[ROCm] Updating XLA custom_call_test to enable it for the ROCm platform

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/46222 from ROCmSoftwarePlatform:google_upstream_rocm_fix_xla_custom_call_test 823c406a07c9f2644ef82c0407f5f6f3c895428a
PiperOrigin-RevId: 350730351
Change-Id: Id64bd074fda2b185e4791c926bd59b944db60a11
This commit is contained in:
Deven Desai 2021-01-08 03:05:28 -08:00 committed by TensorFlower Gardener
parent cf603aa4f9
commit 277e22a015
2 changed files with 47 additions and 23 deletions

View File

@ -119,8 +119,6 @@ tf_cc_test(
tags = tf_cuda_tests_tags(),
deps = [
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/client:xla_builder",
@ -129,6 +127,10 @@ tf_cc_test(
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/core:test",
] + if_cuda_is_configured([
"@local_config_cuda//cuda:cuda_headers",
]) + if_rocm_is_configured([
"@local_config_rocm//rocm:rocm_headers",
]),
)

View File

@ -13,9 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/driver_types.h"
#define PLATFORM "CUDA"
#elif TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
#define PLATFORM "ROCM"
#endif
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
@ -23,6 +29,23 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/stream_executor/gpu/gpu_types.h"
#if GOOGLE_CUDA
#define gpuSuccess cudaSuccess
#define gpuMemcpyAsync cudaMemcpyAsync
#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice
#define gpuMemcpy cudaMemcpy
#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
#elif TENSORFLOW_USE_ROCM
#define gpuSuccess hipSuccess
#define gpuMemcpyAsync hipMemcpyAsync
#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define gpuMemcpy hipMemcpy
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
#endif
namespace xla {
namespace {
@ -30,11 +53,11 @@ namespace {
class CustomCallTest : public ClientLibraryTestBase {};
bool is_invoked_called = false;
void Callback_IsInvoked(CUstream /*stream*/, void** /*buffers*/,
void Callback_IsInvoked(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/,
const char* /*opaque*/, size_t /*opaque_len*/) {
is_invoked_called = true;
}
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_IsInvoked, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_IsInvoked, PLATFORM);
TEST_F(CustomCallTest, IsInvoked) {
XlaBuilder b(TestName());
@ -53,16 +76,15 @@ TEST_F(CustomCallTest, UnknownTarget) {
/*opaque=*/"");
ASSERT_FALSE(Execute(&b, {}).ok());
}
void Callback_Memcpy(CUstream stream, void** buffers, const char* /*opaque*/,
size_t /*opaque_len*/) {
void Callback_Memcpy(se::gpu::GpuStreamHandle stream, void** buffers,
const char* /*opaque*/, size_t /*opaque_len*/) {
void* src = buffers[0];
void* dst = buffers[1];
auto err = cudaMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128,
cudaMemcpyDeviceToDevice, stream);
ASSERT_EQ(err, cudaSuccess);
auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128,
gpuMemcpyDeviceToDevice, stream);
ASSERT_EQ(err, gpuSuccess);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, PLATFORM);
TEST_F(CustomCallTest, Memcpy) {
XlaBuilder b(TestName());
CustomCall(&b, "Callback_Memcpy",
@ -74,12 +96,12 @@ TEST_F(CustomCallTest, Memcpy) {
// Check that opaque handles nulls within the string.
std::string& kExpectedOpaque = *new std::string("abc\0def", 7);
void Callback_Opaque(CUstream /*stream*/, void** /*buffers*/,
void Callback_Opaque(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/,
const char* opaque, size_t opaque_len) {
std::string opaque_str(opaque, opaque_len);
ASSERT_EQ(opaque_str, kExpectedOpaque);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Opaque, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Opaque, PLATFORM);
TEST_F(CustomCallTest, Opaque) {
XlaBuilder b(TestName());
CustomCall(&b, "Callback_Opaque", /*operands=*/{},
@ -87,7 +109,7 @@ TEST_F(CustomCallTest, Opaque) {
TF_ASSERT_OK(Execute(&b, {}).status());
}
void Callback_SubBuffers(CUstream stream, void** buffers,
void Callback_SubBuffers(se::gpu::GpuStreamHandle stream, void** buffers,
const char* /*opaque*/, size_t /*opaque_len*/) {
// `buffers` is a flat array containing device pointers to the following.
//
@ -103,16 +125,16 @@ void Callback_SubBuffers(CUstream stream, void** buffers,
// Set output leaf buffers, copying data from the corresponding same-sized
// inputs.
cudaMemcpyAsync(buffers[4], buffers[3], 8 * sizeof(float),
cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(buffers[5], buffers[0], 128 * sizeof(float),
cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(buffers[6], buffers[1], 256 * sizeof(float),
cudaMemcpyDeviceToDevice, stream);
cudaMemcpyAsync(buffers[7], buffers[2], 1024 * sizeof(float),
cudaMemcpyDeviceToDevice, stream);
gpuMemcpyAsync(buffers[4], buffers[3], 8 * sizeof(float),
gpuMemcpyDeviceToDevice, stream);
gpuMemcpyAsync(buffers[5], buffers[0], 128 * sizeof(float),
gpuMemcpyDeviceToDevice, stream);
gpuMemcpyAsync(buffers[6], buffers[1], 256 * sizeof(float),
gpuMemcpyDeviceToDevice, stream);
gpuMemcpyAsync(buffers[7], buffers[2], 1024 * sizeof(float),
gpuMemcpyDeviceToDevice, stream);
}
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, "CUDA");
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, PLATFORM);
TEST_F(CustomCallTest, SubBuffers) {
XlaBuilder b(TestName());
CustomCall(&b, "Callback_SubBuffers", /*operands=*/