PR #46222: [ROCm] Updating XLA custom_call_test to enable it for the ROCm platform
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/46222

--------------------------

/cc @chsigg @cheshire @nvining-work

Copybara import of the project:

--
823c406a07c9f2644ef82c0407f5f6f3c895428a by Deven Desai <deven.desai.amd@gmail.com>:

[ROCm] Updating XLA custom_call_test to enable it for the ROCm platform

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/46222 from ROCmSoftwarePlatform:google_upstream_rocm_fix_xla_custom_call_test 823c406a07c9f2644ef82c0407f5f6f3c895428a
PiperOrigin-RevId: 350730351
Change-Id: Id64bd074fda2b185e4791c926bd59b944db60a11
This commit is contained in:
parent
cf603aa4f9
commit
277e22a015
@ -119,8 +119,6 @@ tf_cc_test(
|
|||||||
tags = tf_cuda_tests_tags(),
|
tags = tf_cuda_tests_tags(),
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||||
] + if_cuda_is_configured([
|
|
||||||
"@local_config_cuda//cuda:cuda_headers",
|
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:test_helpers",
|
"//tensorflow/compiler/xla:test_helpers",
|
||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
@ -129,6 +127,10 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla/service:gpu_plugin",
|
"//tensorflow/compiler/xla/service:gpu_plugin",
|
||||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
|
] + if_cuda_is_configured([
|
||||||
|
"@local_config_cuda//cuda:cuda_headers",
|
||||||
|
]) + if_rocm_is_configured([
|
||||||
|
"@local_config_rocm//rocm:rocm_headers",
|
||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,9 +13,15 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
#include "third_party/gpus/cuda/include/cuda.h"
|
#include "third_party/gpus/cuda/include/cuda.h"
|
||||||
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
|
||||||
#include "third_party/gpus/cuda/include/driver_types.h"
|
#include "third_party/gpus/cuda/include/driver_types.h"
|
||||||
|
#define PLATFORM "CUDA"
|
||||||
|
#elif TENSORFLOW_USE_ROCM
|
||||||
|
#include "rocm/include/hip/hip_runtime.h"
|
||||||
|
#define PLATFORM "ROCM"
|
||||||
|
#endif
|
||||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
|
||||||
@ -23,6 +29,23 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||||
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
|
||||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/stream_executor/gpu/gpu_types.h"
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#define gpuSuccess cudaSuccess
|
||||||
|
#define gpuMemcpyAsync cudaMemcpyAsync
|
||||||
|
#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice
|
||||||
|
#define gpuMemcpy cudaMemcpy
|
||||||
|
#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
|
||||||
|
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
|
||||||
|
#elif TENSORFLOW_USE_ROCM
|
||||||
|
#define gpuSuccess hipSuccess
|
||||||
|
#define gpuMemcpyAsync hipMemcpyAsync
|
||||||
|
#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
||||||
|
#define gpuMemcpy hipMemcpy
|
||||||
|
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
|
||||||
|
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
|
||||||
|
#endif
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
@ -30,11 +53,11 @@ namespace {
|
|||||||
class CustomCallTest : public ClientLibraryTestBase {};
|
class CustomCallTest : public ClientLibraryTestBase {};
|
||||||
|
|
||||||
bool is_invoked_called = false;
|
bool is_invoked_called = false;
|
||||||
void Callback_IsInvoked(CUstream /*stream*/, void** /*buffers*/,
|
void Callback_IsInvoked(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/,
|
||||||
const char* /*opaque*/, size_t /*opaque_len*/) {
|
const char* /*opaque*/, size_t /*opaque_len*/) {
|
||||||
is_invoked_called = true;
|
is_invoked_called = true;
|
||||||
}
|
}
|
||||||
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_IsInvoked, "CUDA");
|
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_IsInvoked, PLATFORM);
|
||||||
|
|
||||||
TEST_F(CustomCallTest, IsInvoked) {
|
TEST_F(CustomCallTest, IsInvoked) {
|
||||||
XlaBuilder b(TestName());
|
XlaBuilder b(TestName());
|
||||||
@ -53,16 +76,15 @@ TEST_F(CustomCallTest, UnknownTarget) {
|
|||||||
/*opaque=*/"");
|
/*opaque=*/"");
|
||||||
ASSERT_FALSE(Execute(&b, {}).ok());
|
ASSERT_FALSE(Execute(&b, {}).ok());
|
||||||
}
|
}
|
||||||
|
void Callback_Memcpy(se::gpu::GpuStreamHandle stream, void** buffers,
|
||||||
void Callback_Memcpy(CUstream stream, void** buffers, const char* /*opaque*/,
|
const char* /*opaque*/, size_t /*opaque_len*/) {
|
||||||
size_t /*opaque_len*/) {
|
|
||||||
void* src = buffers[0];
|
void* src = buffers[0];
|
||||||
void* dst = buffers[1];
|
void* dst = buffers[1];
|
||||||
auto err = cudaMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128,
|
auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128,
|
||||||
cudaMemcpyDeviceToDevice, stream);
|
gpuMemcpyDeviceToDevice, stream);
|
||||||
ASSERT_EQ(err, cudaSuccess);
|
ASSERT_EQ(err, gpuSuccess);
|
||||||
}
|
}
|
||||||
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, "CUDA");
|
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, PLATFORM);
|
||||||
TEST_F(CustomCallTest, Memcpy) {
|
TEST_F(CustomCallTest, Memcpy) {
|
||||||
XlaBuilder b(TestName());
|
XlaBuilder b(TestName());
|
||||||
CustomCall(&b, "Callback_Memcpy",
|
CustomCall(&b, "Callback_Memcpy",
|
||||||
@ -74,12 +96,12 @@ TEST_F(CustomCallTest, Memcpy) {
|
|||||||
|
|
||||||
// Check that opaque handles nulls within the string.
|
// Check that opaque handles nulls within the string.
|
||||||
std::string& kExpectedOpaque = *new std::string("abc\0def", 7);
|
std::string& kExpectedOpaque = *new std::string("abc\0def", 7);
|
||||||
void Callback_Opaque(CUstream /*stream*/, void** /*buffers*/,
|
void Callback_Opaque(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/,
|
||||||
const char* opaque, size_t opaque_len) {
|
const char* opaque, size_t opaque_len) {
|
||||||
std::string opaque_str(opaque, opaque_len);
|
std::string opaque_str(opaque, opaque_len);
|
||||||
ASSERT_EQ(opaque_str, kExpectedOpaque);
|
ASSERT_EQ(opaque_str, kExpectedOpaque);
|
||||||
}
|
}
|
||||||
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Opaque, "CUDA");
|
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Opaque, PLATFORM);
|
||||||
TEST_F(CustomCallTest, Opaque) {
|
TEST_F(CustomCallTest, Opaque) {
|
||||||
XlaBuilder b(TestName());
|
XlaBuilder b(TestName());
|
||||||
CustomCall(&b, "Callback_Opaque", /*operands=*/{},
|
CustomCall(&b, "Callback_Opaque", /*operands=*/{},
|
||||||
@ -87,7 +109,7 @@ TEST_F(CustomCallTest, Opaque) {
|
|||||||
TF_ASSERT_OK(Execute(&b, {}).status());
|
TF_ASSERT_OK(Execute(&b, {}).status());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Callback_SubBuffers(CUstream stream, void** buffers,
|
void Callback_SubBuffers(se::gpu::GpuStreamHandle stream, void** buffers,
|
||||||
const char* /*opaque*/, size_t /*opaque_len*/) {
|
const char* /*opaque*/, size_t /*opaque_len*/) {
|
||||||
// `buffers` is a flat array containing device pointers to the following.
|
// `buffers` is a flat array containing device pointers to the following.
|
||||||
//
|
//
|
||||||
@ -103,16 +125,16 @@ void Callback_SubBuffers(CUstream stream, void** buffers,
|
|||||||
|
|
||||||
// Set output leaf buffers, copying data from the corresponding same-sized
|
// Set output leaf buffers, copying data from the corresponding same-sized
|
||||||
// inputs.
|
// inputs.
|
||||||
cudaMemcpyAsync(buffers[4], buffers[3], 8 * sizeof(float),
|
gpuMemcpyAsync(buffers[4], buffers[3], 8 * sizeof(float),
|
||||||
cudaMemcpyDeviceToDevice, stream);
|
gpuMemcpyDeviceToDevice, stream);
|
||||||
cudaMemcpyAsync(buffers[5], buffers[0], 128 * sizeof(float),
|
gpuMemcpyAsync(buffers[5], buffers[0], 128 * sizeof(float),
|
||||||
cudaMemcpyDeviceToDevice, stream);
|
gpuMemcpyDeviceToDevice, stream);
|
||||||
cudaMemcpyAsync(buffers[6], buffers[1], 256 * sizeof(float),
|
gpuMemcpyAsync(buffers[6], buffers[1], 256 * sizeof(float),
|
||||||
cudaMemcpyDeviceToDevice, stream);
|
gpuMemcpyDeviceToDevice, stream);
|
||||||
cudaMemcpyAsync(buffers[7], buffers[2], 1024 * sizeof(float),
|
gpuMemcpyAsync(buffers[7], buffers[2], 1024 * sizeof(float),
|
||||||
cudaMemcpyDeviceToDevice, stream);
|
gpuMemcpyDeviceToDevice, stream);
|
||||||
}
|
}
|
||||||
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, "CUDA");
|
XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, PLATFORM);
|
||||||
TEST_F(CustomCallTest, SubBuffers) {
|
TEST_F(CustomCallTest, SubBuffers) {
|
||||||
XlaBuilder b(TestName());
|
XlaBuilder b(TestName());
|
||||||
CustomCall(&b, "Callback_SubBuffers", /*operands=*/
|
CustomCall(&b, "Callback_SubBuffers", /*operands=*/
|
||||||
|
Loading…
x
Reference in New Issue
Block a user