Merge pull request #40305 from ROCmSoftwarePlatform:google-upstream-kernel-helper-test
PiperOrigin-RevId: 317684035 Change-Id: I743fb0d2c8ba73cbe66a399c921b28d09762ef39
This commit is contained in:
commit
809e482247
|
@ -53,6 +53,8 @@ using gpuEvent_t = cudaEvent_t;
|
|||
#define gpuEventCreate cudaEventCreate
|
||||
#define gpuEventCreateWithFlags cudaEventCreateWithFlags
|
||||
#define gpuEventDisableTiming cudaEventDisableTiming
|
||||
#define gpuDeviceSynchronize cudaDeviceSynchronize
|
||||
#define gpuFree cudaFree
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
using gpuFloatComplex = hipFloatComplex;
|
||||
using gpuDoubleComplex = hipDoubleComplex;
|
||||
|
@ -68,6 +70,8 @@ using cudaError_t = int;
|
|||
#define gpuEventCreate hipEventCreate
|
||||
#define gpuEventCreateWithFlags hipEventCreateWithFlags
|
||||
#define gpuEventDisableTiming hipEventDisableTiming
|
||||
#define gpuDeviceSynchronize hipDeviceSynchronize
|
||||
#define gpuFree hipFree
|
||||
// Fallback used next to the hip* aliases in this header: there is no
// descriptive message here, so the numeric error code is rendered as decimal
// text.
static std::string cudaGetErrorString(int err) {
  std::string code_text = std::to_string(err);
  return code_text;
}
|
||||
#endif
|
||||
|
||||
|
|
|
@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include <time.h>
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
@ -25,14 +27,14 @@ limitations under the License.
|
|||
|
||||
// Synchronizes the device, reads cudaGetLastError(), and reports a mismatch
// with cudaSuccess through EXPECT_EQ, appending cudaGetErrorString(err).
// Expands to a braced block, so call sites use it without a trailing ';'.
// NOTE(review): this span is a rendered diff hunk; cudaDeviceSynchronize()
// looks like the removed line and gpuDeviceSynchronize() its replacement --
// confirm against the real file before relying on both calls being present.
#define CUDA_EXPECT_SUCCESS \
|
||||
{ \
|
||||
cudaDeviceSynchronize(); \
|
||||
gpuDeviceSynchronize(); \
|
||||
cudaError_t err = cudaGetLastError(); \
|
||||
EXPECT_EQ(cudaSuccess, err) << cudaGetErrorString(err); \
|
||||
}
|
||||
|
||||
// Same check as CUDA_EXPECT_SUCCESS, but reports through ASSERT_EQ instead of
// EXPECT_EQ: synchronizes the device, reads cudaGetLastError(), and compares it
// with cudaSuccess, appending cudaGetErrorString(err) to the failure message.
// Expands to a braced block, so call sites use it without a trailing ';'.
// NOTE(review): this span is a rendered diff hunk; cudaDeviceSynchronize()
// looks like the removed line and gpuDeviceSynchronize() its replacement --
// confirm against the real file before relying on both calls being present.
#define CUDA_ASSERT_SUCCESS \
|
||||
{ \
|
||||
cudaDeviceSynchronize(); \
|
||||
gpuDeviceSynchronize(); \
|
||||
cudaError_t err = cudaGetLastError(); \
|
||||
ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err); \
|
||||
}
|
||||
|
@ -94,8 +96,7 @@ __global__ void Count3D(Gpu3DLaunchConfig config, int bufsize,
|
|||
}
|
||||
}
|
||||
|
||||
__global__ void CudaShuffleGetSrcLaneTest(
|
||||
unsigned* __restrict__ failure_count) {
|
||||
__global__ void GpuShuffleGetSrcLaneTest(unsigned* __restrict__ failure_count) {
|
||||
unsigned lane_id = GpuLaneId();
|
||||
for (int width = warpSize; width > 1; width /= 2) {
|
||||
auto check_result = [&](const char* op_name, int param, unsigned actual,
|
||||
|
@ -103,31 +104,38 @@ __global__ void CudaShuffleGetSrcLaneTest(
|
|||
if (actual != expected) {
|
||||
printf("Cuda%sGetSrcLane(%d, %d) for lane %d returned %d, not %d\n",
|
||||
op_name, param, width, lane_id, actual, expected);
|
||||
CudaAtomicAdd(failure_count, 1);
|
||||
GpuAtomicAdd(failure_count, 1);
|
||||
}
|
||||
};
|
||||
|
||||
for (int src_lane = -warpSize; src_lane <= warpSize; ++src_lane) {
|
||||
unsigned actual_lane = detail::CudaShuffleGetSrcLane(src_lane, width);
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
if (src_lane < 0 || src_lane >= width) continue;
|
||||
#endif
|
||||
unsigned actual_lane = detail::GpuShuffleGetSrcLane(src_lane, width);
|
||||
unsigned expect_lane =
|
||||
CudaShuffleSync(kCudaWarpAll, lane_id, src_lane, width);
|
||||
GpuShuffleSync(kCudaWarpAll, lane_id, src_lane, width);
|
||||
check_result("Shuffle", src_lane, actual_lane, expect_lane);
|
||||
}
|
||||
|
||||
for (unsigned delta = 0; delta <= warpSize; ++delta) {
|
||||
unsigned actual_lane = detail::CudaShuffleUpGetSrcLane(delta, width);
|
||||
unsigned actual_lane = detail::GpuShuffleUpGetSrcLane(delta, width);
|
||||
unsigned expect_lane =
|
||||
CudaShuffleUpSync(kCudaWarpAll, lane_id, delta, width);
|
||||
GpuShuffleUpSync(kCudaWarpAll, lane_id, delta, width);
|
||||
check_result("ShuffleUp", delta, actual_lane, expect_lane);
|
||||
}
|
||||
|
||||
for (unsigned delta = 0; delta <= warpSize; ++delta) {
|
||||
unsigned actual_lane = detail::CudaShuffleDownGetSrcLane(delta, width);
|
||||
unsigned actual_lane = detail::GpuShuffleDownGetSrcLane(delta, width);
|
||||
unsigned expect_lane =
|
||||
CudaShuffleDownSync(kCudaWarpAll, lane_id, delta, width);
|
||||
GpuShuffleDownSync(kCudaWarpAll, lane_id, delta, width);
|
||||
check_result("ShuffleDown", delta, actual_lane, expect_lane);
|
||||
}
|
||||
|
||||
for (int lane_lane = warpSize; lane_lane > 0; lane_lane /= 2) {
|
||||
unsigned actual_lane = detail::CudaShuffleXorGetSrcLane(lane_lane, width);
|
||||
unsigned actual_lane = detail::GpuShuffleXorGetSrcLane(lane_lane, width);
|
||||
unsigned expect_lane =
|
||||
CudaShuffleXorSync(kCudaWarpAll, lane_id, lane_lane, width);
|
||||
GpuShuffleXorSync(kCudaWarpAll, lane_id, lane_lane, width);
|
||||
check_result("ShuffleXor", lane_lane, actual_lane, expect_lane);
|
||||
}
|
||||
}
|
||||
|
@ -137,19 +145,32 @@ __global__ void CudaShuffleGetSrcLaneTest(
|
|||
|
||||
class GpuLaunchConfigTest : public ::testing::Test {
|
||||
protected:
|
||||
const int bufsize = 1024;
|
||||
static const int bufsize = 1024;
|
||||
int* outbuf = nullptr;
|
||||
int* outbuf_host = nullptr;
|
||||
int hostbuf[bufsize];
|
||||
Eigen::GpuStreamDevice stream;
|
||||
Eigen::GpuDevice d = Eigen::GpuDevice(&stream);
|
||||
|
||||
void copyToHost() {
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
hipMemcpy(hostbuf, outbuf, sizeof(int) * bufsize, hipMemcpyDeviceToHost);
|
||||
#endif
|
||||
}
|
||||
virtual void SetUp() {
|
||||
#if GOOGLE_CUDA
|
||||
cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize);
|
||||
outbuf_host = outbuf;
|
||||
#else
|
||||
cudaError_t err = hipMalloc(&outbuf, sizeof(int) * bufsize);
|
||||
outbuf_host = hostbuf;
|
||||
#endif
|
||||
ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err);
|
||||
}
|
||||
|
||||
virtual void TearDown() {
|
||||
cudaDeviceSynchronize();
|
||||
cudaFree(outbuf);
|
||||
gpuDeviceSynchronize();
|
||||
gpuFree(outbuf);
|
||||
outbuf = nullptr;
|
||||
}
|
||||
};
|
||||
|
@ -158,28 +179,32 @@ TEST_F(GpuLaunchConfigTest, GetGpuLaunchConfig) {
|
|||
GpuLaunchConfig cfg;
|
||||
|
||||
// test valid inputs
|
||||
#define TEST_LAUNCH_PARAMETER(work_element_count) \
|
||||
cfg = GetGpuLaunchConfig(bufsize, d); \
|
||||
TF_CHECK_OK(GpuLaunchKernel(SetOutbufZero, cfg.block_count, \
|
||||
cfg.thread_per_block, 0, d.stream(), cfg, \
|
||||
outbuf)); \
|
||||
CUDA_ASSERT_SUCCESS \
|
||||
cfg = GetGpuLaunchConfig(work_element_count, d); \
|
||||
TF_CHECK_OK(GpuLaunchKernel(Count1D, cfg.block_count, cfg.thread_per_block, \
|
||||
0, d.stream(), cfg, bufsize, outbuf)); \
|
||||
CUDA_EXPECT_SUCCESS \
|
||||
EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0)); \
|
||||
\
|
||||
cfg = GetGpuLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \
|
||||
TF_CHECK_OK(GpuLaunchKernel(SetOutbufZero, cfg.block_count, \
|
||||
cfg.thread_per_block, 0, d.stream(), cfg, \
|
||||
outbuf)); \
|
||||
CUDA_ASSERT_SUCCESS \
|
||||
cfg = GetGpuLaunchConfig(work_element_count, d, Count1D, 0, 0); \
|
||||
TF_CHECK_OK(GpuLaunchKernel(Count1D, cfg.block_count, cfg.thread_per_block, \
|
||||
0, d.stream(), cfg, bufsize, outbuf)); \
|
||||
CUDA_EXPECT_SUCCESS \
|
||||
EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0))
|
||||
#define TEST_LAUNCH_PARAMETER(work_element_count) \
|
||||
cfg = GetGpuLaunchConfig(bufsize, d); \
|
||||
TF_CHECK_OK(GpuLaunchKernel(SetOutbufZero, cfg.block_count, \
|
||||
cfg.thread_per_block, 0, d.stream(), cfg, \
|
||||
outbuf)); \
|
||||
CUDA_ASSERT_SUCCESS \
|
||||
cfg = GetGpuLaunchConfig(work_element_count, d); \
|
||||
TF_CHECK_OK(GpuLaunchKernel(Count1D, cfg.block_count, cfg.thread_per_block, \
|
||||
0, d.stream(), cfg, bufsize, outbuf)); \
|
||||
CUDA_EXPECT_SUCCESS \
|
||||
copyToHost(); \
|
||||
EXPECT_EQ(work_element_count, \
|
||||
std::accumulate(outbuf_host, outbuf_host + bufsize, 0)); \
|
||||
\
|
||||
cfg = GetGpuLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \
|
||||
TF_CHECK_OK(GpuLaunchKernel(SetOutbufZero, cfg.block_count, \
|
||||
cfg.thread_per_block, 0, d.stream(), cfg, \
|
||||
outbuf)); \
|
||||
CUDA_ASSERT_SUCCESS \
|
||||
cfg = GetGpuLaunchConfig(work_element_count, d, Count1D, 0, 0); \
|
||||
TF_CHECK_OK(GpuLaunchKernel(Count1D, cfg.block_count, cfg.thread_per_block, \
|
||||
0, d.stream(), cfg, bufsize, outbuf)); \
|
||||
CUDA_EXPECT_SUCCESS \
|
||||
copyToHost(); \
|
||||
EXPECT_EQ(work_element_count, \
|
||||
std::accumulate(outbuf_host, outbuf_host + bufsize, 0));
|
||||
|
||||
TEST_LAUNCH_PARAMETER(128);
|
||||
TEST_LAUNCH_PARAMETER(129);
|
||||
|
@ -221,7 +246,9 @@ TEST_F(GpuLaunchConfigTest, GetGpu2DLaunchConfig) {
|
|||
TF_EXPECT_OK(GpuLaunchKernel(Count2D, cfg.block_count, cfg.thread_per_block, \
|
||||
0, d.stream(), cfg, bufsize, outbuf)); \
|
||||
CUDA_EXPECT_SUCCESS \
|
||||
EXPECT_EQ(dimx* dimy, std::accumulate(outbuf, outbuf + bufsize, 0)); \
|
||||
copyToHost(); \
|
||||
EXPECT_EQ(dimx* dimy, \
|
||||
std::accumulate(outbuf_host, outbuf_host + bufsize, 0)); \
|
||||
\
|
||||
cfg1d = GetGpuLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \
|
||||
TF_EXPECT_OK(GpuLaunchKernel(SetOutbufZero, cfg1d.block_count, \
|
||||
|
@ -232,7 +259,8 @@ TEST_F(GpuLaunchConfigTest, GetGpu2DLaunchConfig) {
|
|||
TF_EXPECT_OK(GpuLaunchKernel(Count2D, cfg.block_count, cfg.thread_per_block, \
|
||||
0, d.stream(), cfg, bufsize, outbuf)); \
|
||||
CUDA_EXPECT_SUCCESS \
|
||||
EXPECT_EQ(dimx* dimy, std::accumulate(outbuf, outbuf + bufsize, 0))
|
||||
copyToHost(); \
|
||||
EXPECT_EQ(dimx* dimy, std::accumulate(outbuf_host, outbuf_host + bufsize, 0))
|
||||
|
||||
TEST_LAUNCH_PARAMETER(128, 128);
|
||||
TEST_LAUNCH_PARAMETER(129, 64);
|
||||
|
@ -263,7 +291,9 @@ TEST_F(GpuLaunchConfigTest, GetGpu3DLaunchConfig) {
|
|||
TF_EXPECT_OK(GpuLaunchKernel(Count3D, cfg.block_count, cfg.thread_per_block, \
|
||||
0, d.stream(), cfg, bufsize, outbuf)); \
|
||||
CUDA_EXPECT_SUCCESS \
|
||||
EXPECT_EQ(dimx* dimy* dimz, std::accumulate(outbuf, outbuf + bufsize, 0))
|
||||
copyToHost(); \
|
||||
EXPECT_EQ(dimx* dimy* dimz, \
|
||||
std::accumulate(outbuf_host, outbuf_host + bufsize, 0))
|
||||
|
||||
TEST_LAUNCH_PARAMETER(128, 128, 128);
|
||||
TEST_LAUNCH_PARAMETER(129, 64, 1024);
|
||||
|
@ -282,15 +312,19 @@ TEST_F(GpuLaunchConfigTest, GetGpu3DLaunchConfig) {
|
|||
|
||||
// Launches GpuShuffleGetSrcLaneTest on one block of TF_RED_WARPSIZE threads and
// expects the kernel to report zero mismatches through `failure_count`.
TEST(CudaDeviceFunctionsTest, ShuffleGetSrcLane) {
  unsigned* failure_count = nullptr;
#if GOOGLE_CUDA
  ASSERT_EQ(cudaMallocManaged(&failure_count, sizeof(unsigned)), cudaSuccess);
#else
  ASSERT_EQ(hipHostMalloc(&failure_count, sizeof(unsigned), 0), cudaSuccess);
#endif
  *failure_count = 0;
  TF_EXPECT_OK(GpuLaunchKernel(GpuShuffleGetSrcLaneTest, 1, TF_RED_WARPSIZE, 0,
                               nullptr, failure_count));

  // Checks below are non-fatal so the buffer is released on every path.
  const bool synced = (gpuDeviceSynchronize() == cudaSuccess);
  EXPECT_TRUE(synced);
  if (synced) {
    // Only read the counter once the kernel is known to have finished.
    EXPECT_EQ(*failure_count, 0u);
  }

  // Release with the deallocator that matches the allocator used above:
  // hipHostMalloc memory must go through hipHostFree, not hipFree.
#if GOOGLE_CUDA
  EXPECT_EQ(cudaFree(failure_count), cudaSuccess);
#else
  EXPECT_EQ(hipHostFree(failure_count), cudaSuccess);
#endif
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
|
Loading…
Reference in New Issue