PR #46222: [ROCm] Updating XLA custom_call_test to enable it for the ROCm platform
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/46222

/cc @chsigg @cheshire @nvining-work

Copybara import of the project:

-- 823c406a07c9f2644ef82c0407f5f6f3c895428a by Deven Desai <deven.desai.amd@gmail.com>:

[ROCm] Updating XLA custom_call_test to enable it for the ROCm platform

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/tensorflow/pull/46222 from ROCmSoftwarePlatform:google_upstream_rocm_fix_xla_custom_call_test 823c406a07c9f2644ef82c0407f5f6f3c895428a
PiperOrigin-RevId: 350730351
Change-Id: Id64bd074fda2b185e4791c926bd59b944db60a11
Commit 277e22a015 (parent cf603aa4f9)
BUILD:

@@ -119,8 +119,6 @@ tf_cc_test(
     tags = tf_cuda_tests_tags(),
     deps = [
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
-    ] + if_cuda_is_configured([
-        "@local_config_cuda//cuda:cuda_headers",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:test_helpers",
         "//tensorflow/compiler/xla/client:xla_builder",
@@ -129,6 +127,10 @@ tf_cc_test(
         "//tensorflow/compiler/xla/service:gpu_plugin",
         "//tensorflow/compiler/xla/tests:client_library_test_base",
         "//tensorflow/core:test",
+    ] + if_cuda_is_configured([
+        "@local_config_cuda//cuda:cuda_headers",
+    ]) + if_rocm_is_configured([
+        "@local_config_rocm//rocm:rocm_headers",
+    ]),
 )
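Note: if_cuda_is_configured and if_rocm_is_configured return their argument list when the corresponding toolchain is configured and an empty list otherwise. The change therefore moves the common deps out of the CUDA-only branch so they apply unconditionally, leaving only the CUDA and ROCm header targets behind their respective configuration guards.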
custom_call_test.cc:

@@ -13,9 +13,15 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
+#if GOOGLE_CUDA
 #include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
 #include "third_party/gpus/cuda/include/driver_types.h"
+#define PLATFORM "CUDA"
+#elif TENSORFLOW_USE_ROCM
+#include "rocm/include/hip/hip_runtime.h"
+#define PLATFORM "ROCM"
+#endif
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
@@ -23,6 +29,23 @@ limitations under the License.
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/stream_executor/gpu/gpu_types.h"
+
+#if GOOGLE_CUDA
+#define gpuSuccess cudaSuccess
+#define gpuMemcpyAsync cudaMemcpyAsync
+#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice
+#define gpuMemcpy cudaMemcpy
+#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
+#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
+#elif TENSORFLOW_USE_ROCM
+#define gpuSuccess hipSuccess
+#define gpuMemcpyAsync hipMemcpyAsync
+#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice
+#define gpuMemcpy hipMemcpy
+#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
+#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
+#endif
 
 namespace xla {
 namespace {
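The block above gives every memcpy-related symbol a single gpu* spelling, and the switch from CUstream to se::gpu::GpuStreamHandle (from the newly included gpu_types.h) does the same for the stream type. A minimal standalone sketch of the pattern, with the include paths and the gpuStream alias assumed for illustration:

#include <cstddef>

#if GOOGLE_CUDA
#include <cuda_runtime_api.h>
using gpuStream = cudaStream_t;  // stream handle under CUDA
#define gpuMemcpyAsync cudaMemcpyAsync
#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice
#elif TENSORFLOW_USE_ROCM
#include <hip/hip_runtime.h>
using gpuStream = hipStream_t;   // stream handle under ROCm
#define gpuMemcpyAsync hipMemcpyAsync
#define gpuMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#endif

// One body compiles unchanged for either backend: the copy is enqueued on the
// caller-supplied stream, device pointer to device pointer.
void CopyLeaf(gpuStream stream, void* dst, const void* src, std::size_t bytes) {
  gpuMemcpyAsync(dst, src, bytes, gpuMemcpyDeviceToDevice, stream);
}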
@@ -30,11 +53,11 @@ namespace {
 class CustomCallTest : public ClientLibraryTestBase {};
 
 bool is_invoked_called = false;
-void Callback_IsInvoked(CUstream /*stream*/, void** /*buffers*/,
+void Callback_IsInvoked(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/,
                         const char* /*opaque*/, size_t /*opaque_len*/) {
   is_invoked_called = true;
 }
-XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_IsInvoked, "CUDA");
+XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_IsInvoked, PLATFORM);
 
 TEST_F(CustomCallTest, IsInvoked) {
   XlaBuilder b(TestName());
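Registering with PLATFORM instead of the literal "CUDA" is the key to the port: XLA's custom-call registry is keyed by both the target name and a platform string, so a ROCm build must register (and look up) targets under "ROCM". A conceptual sketch of that keying, not the actual CustomCallTargetRegistry implementation:

#include <map>
#include <string>
#include <utility>

// Hypothetical stand-in for the real registry declared in
// custom_call_target_registry.h: (name, platform) -> function pointer.
class TargetRegistry {
 public:
  void Register(const std::string& name, void* fn,
                const std::string& platform) {
    targets_[{name, platform}] = fn;
  }
  void* Lookup(const std::string& name, const std::string& platform) const {
    auto it = targets_.find({name, platform});
    return it == targets_.end() ? nullptr : it->second;
  }

 private:
  std::map<std::pair<std::string, std::string>, void*> targets_;
};

With PLATFORM expanding to "CUDA" or "ROCM" per build, the same registration line serves both backends.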
@@ -53,16 +76,15 @@ TEST_F(CustomCallTest, UnknownTarget) {
              /*opaque=*/"");
   ASSERT_FALSE(Execute(&b, {}).ok());
 }
 
-void Callback_Memcpy(CUstream stream, void** buffers, const char* /*opaque*/,
-                     size_t /*opaque_len*/) {
+void Callback_Memcpy(se::gpu::GpuStreamHandle stream, void** buffers,
+                     const char* /*opaque*/, size_t /*opaque_len*/) {
   void* src = buffers[0];
   void* dst = buffers[1];
-  auto err = cudaMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128,
-                             cudaMemcpyDeviceToDevice, stream);
-  ASSERT_EQ(err, cudaSuccess);
+  auto err = gpuMemcpyAsync(dst, src, /*count=*/sizeof(float) * 128,
+                            gpuMemcpyDeviceToDevice, stream);
+  ASSERT_EQ(err, gpuSuccess);
 }
-XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, "CUDA");
+XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Memcpy, PLATFORM);
 TEST_F(CustomCallTest, Memcpy) {
   XlaBuilder b(TestName());
   CustomCall(&b, "Callback_Memcpy",
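For context, the builder side of such a test looks roughly like this (a sketch with an assumed operand shape, since the diff truncates the test body): the custom call consumes one f32[128] operand and declares an f32[128] result, which Callback_Memcpy fills with a device-to-device copy.

XlaBuilder b("Memcpy");
XlaOp input = Broadcast(ConstantR0WithType(&b, F32, 42.0f), {128});
CustomCall(&b, "Callback_Memcpy", /*operands=*/{input},
           ShapeUtil::MakeShape(F32, {128}));

Because the copy is asynchronous on the stream XLA passes in, the test relies on XLA synchronizing that stream before reading back the result.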
@@ -74,12 +96,12 @@ TEST_F(CustomCallTest, Memcpy) {
 
 // Check that opaque handles nulls within the string.
 std::string& kExpectedOpaque = *new std::string("abc\0def", 7);
-void Callback_Opaque(CUstream /*stream*/, void** /*buffers*/,
+void Callback_Opaque(se::gpu::GpuStreamHandle /*stream*/, void** /*buffers*/,
                      const char* opaque, size_t opaque_len) {
   std::string opaque_str(opaque, opaque_len);
   ASSERT_EQ(opaque_str, kExpectedOpaque);
 }
-XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Opaque, "CUDA");
+XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_Opaque, PLATFORM);
 TEST_F(CustomCallTest, Opaque) {
   XlaBuilder b(TestName());
   CustomCall(&b, "Callback_Opaque", /*operands=*/{},
@@ -87,7 +109,7 @@ TEST_F(CustomCallTest, Opaque) {
   TF_ASSERT_OK(Execute(&b, {}).status());
 }
 
-void Callback_SubBuffers(CUstream stream, void** buffers,
+void Callback_SubBuffers(se::gpu::GpuStreamHandle stream, void** buffers,
                          const char* /*opaque*/, size_t /*opaque_len*/) {
   // `buffers` is a flat array containing device pointers to the following.
   //
@@ -103,16 +125,16 @@ void Callback_SubBuffers(CUstream stream, void** buffers,
 
   // Set output leaf buffers, copying data from the corresponding same-sized
   // inputs.
-  cudaMemcpyAsync(buffers[4], buffers[3], 8 * sizeof(float),
-                  cudaMemcpyDeviceToDevice, stream);
-  cudaMemcpyAsync(buffers[5], buffers[0], 128 * sizeof(float),
-                  cudaMemcpyDeviceToDevice, stream);
-  cudaMemcpyAsync(buffers[6], buffers[1], 256 * sizeof(float),
-                  cudaMemcpyDeviceToDevice, stream);
-  cudaMemcpyAsync(buffers[7], buffers[2], 1024 * sizeof(float),
-                  cudaMemcpyDeviceToDevice, stream);
+  gpuMemcpyAsync(buffers[4], buffers[3], 8 * sizeof(float),
+                 gpuMemcpyDeviceToDevice, stream);
+  gpuMemcpyAsync(buffers[5], buffers[0], 128 * sizeof(float),
+                 gpuMemcpyDeviceToDevice, stream);
+  gpuMemcpyAsync(buffers[6], buffers[1], 256 * sizeof(float),
+                 gpuMemcpyDeviceToDevice, stream);
+  gpuMemcpyAsync(buffers[7], buffers[2], 1024 * sizeof(float),
+                 gpuMemcpyDeviceToDevice, stream);
 }
-XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, "CUDA");
+XLA_REGISTER_CUSTOM_CALL_TARGET(Callback_SubBuffers, PLATFORM);
 TEST_F(CustomCallTest, SubBuffers) {
   XlaBuilder b(TestName());
   CustomCall(&b, "Callback_SubBuffers", /*operands=*/
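The copy sizes above pin down the flat buffer layout; the listing below is reconstructed from them (operand leaves first, then result leaves) and is an assumption where the diff elides the original comment's full listing:

//   buffers[0]  operand leaf  f32[128]
//   buffers[1]  operand leaf  f32[256]
//   buffers[2]  operand leaf  f32[1024]
//   buffers[3]  operand leaf  f32[8]
//   buffers[4]  result leaf   f32[8]     (copied from buffers[3])
//   buffers[5]  result leaf   f32[128]   (copied from buffers[0])
//   buffers[6]  result leaf   f32[256]   (copied from buffers[1])
//   buffers[7]  result leaf   f32[1024]  (copied from buffers[2])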