Merge pull request #40305 from ROCmSoftwarePlatform:google-upstream-kernel-helper-test

PiperOrigin-RevId: 317684035
Change-Id: I743fb0d2c8ba73cbe66a399c921b28d09762ef39
This commit is contained in:
TensorFlower Gardener 2020-06-22 11:51:48 -07:00
commit 809e482247
2 changed files with 85 additions and 47 deletions

View File

@ -53,6 +53,8 @@ using gpuEvent_t = cudaEvent_t;
#define gpuEventCreate cudaEventCreate
#define gpuEventCreateWithFlags cudaEventCreateWithFlags
#define gpuEventDisableTiming cudaEventDisableTiming
#define gpuDeviceSynchronize cudaDeviceSynchronize
#define gpuFree cudaFree
#elif TENSORFLOW_USE_ROCM
using gpuFloatComplex = hipFloatComplex;
using gpuDoubleComplex = hipDoubleComplex;
@ -68,6 +70,8 @@ using cudaError_t = int;
#define gpuEventCreate hipEventCreate
#define gpuEventCreateWithFlags hipEventCreateWithFlags
#define gpuEventDisableTiming hipEventDisableTiming
#define gpuDeviceSynchronize hipDeviceSynchronize
#define gpuFree hipFree
static std::string cudaGetErrorString(int err) { return std::to_string(err); }
#endif

View File

@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#include <time.h>
#include <numeric>
#include "tensorflow/core/lib/core/status_test_util.h"
@ -25,14 +27,14 @@ limitations under the License.
#define CUDA_EXPECT_SUCCESS \
{ \
cudaDeviceSynchronize(); \
gpuDeviceSynchronize(); \
cudaError_t err = cudaGetLastError(); \
EXPECT_EQ(cudaSuccess, err) << cudaGetErrorString(err); \
}
#define CUDA_ASSERT_SUCCESS \
{ \
cudaDeviceSynchronize(); \
gpuDeviceSynchronize(); \
cudaError_t err = cudaGetLastError(); \
ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err); \
}
@ -94,8 +96,7 @@ __global__ void Count3D(Gpu3DLaunchConfig config, int bufsize,
}
}
__global__ void CudaShuffleGetSrcLaneTest(
unsigned* __restrict__ failure_count) {
__global__ void GpuShuffleGetSrcLaneTest(unsigned* __restrict__ failure_count) {
unsigned lane_id = GpuLaneId();
for (int width = warpSize; width > 1; width /= 2) {
auto check_result = [&](const char* op_name, int param, unsigned actual,
@ -103,31 +104,38 @@ __global__ void CudaShuffleGetSrcLaneTest(
if (actual != expected) {
printf("Cuda%sGetSrcLane(%d, %d) for lane %d returned %d, not %d\n",
op_name, param, width, lane_id, actual, expected);
CudaAtomicAdd(failure_count, 1);
GpuAtomicAdd(failure_count, 1);
}
};
for (int src_lane = -warpSize; src_lane <= warpSize; ++src_lane) {
unsigned actual_lane = detail::CudaShuffleGetSrcLane(src_lane, width);
#if TENSORFLOW_USE_ROCM
if (src_lane < 0 || src_lane >= width) continue;
#endif
unsigned actual_lane = detail::GpuShuffleGetSrcLane(src_lane, width);
unsigned expect_lane =
CudaShuffleSync(kCudaWarpAll, lane_id, src_lane, width);
GpuShuffleSync(kCudaWarpAll, lane_id, src_lane, width);
check_result("Shuffle", src_lane, actual_lane, expect_lane);
}
for (unsigned delta = 0; delta <= warpSize; ++delta) {
unsigned actual_lane = detail::CudaShuffleUpGetSrcLane(delta, width);
unsigned actual_lane = detail::GpuShuffleUpGetSrcLane(delta, width);
unsigned expect_lane =
CudaShuffleUpSync(kCudaWarpAll, lane_id, delta, width);
GpuShuffleUpSync(kCudaWarpAll, lane_id, delta, width);
check_result("ShuffleUp", delta, actual_lane, expect_lane);
}
for (unsigned delta = 0; delta <= warpSize; ++delta) {
unsigned actual_lane = detail::CudaShuffleDownGetSrcLane(delta, width);
unsigned actual_lane = detail::GpuShuffleDownGetSrcLane(delta, width);
unsigned expect_lane =
CudaShuffleDownSync(kCudaWarpAll, lane_id, delta, width);
GpuShuffleDownSync(kCudaWarpAll, lane_id, delta, width);
check_result("ShuffleDown", delta, actual_lane, expect_lane);
}
for (int lane_lane = warpSize; lane_lane > 0; lane_lane /= 2) {
unsigned actual_lane = detail::CudaShuffleXorGetSrcLane(lane_lane, width);
unsigned actual_lane = detail::GpuShuffleXorGetSrcLane(lane_lane, width);
unsigned expect_lane =
CudaShuffleXorSync(kCudaWarpAll, lane_id, lane_lane, width);
GpuShuffleXorSync(kCudaWarpAll, lane_id, lane_lane, width);
check_result("ShuffleXor", lane_lane, actual_lane, expect_lane);
}
}
@ -137,19 +145,32 @@ __global__ void CudaShuffleGetSrcLaneTest(
class GpuLaunchConfigTest : public ::testing::Test {
protected:
const int bufsize = 1024;
static const int bufsize = 1024;
int* outbuf = nullptr;
int* outbuf_host = nullptr;
int hostbuf[bufsize];
Eigen::GpuStreamDevice stream;
Eigen::GpuDevice d = Eigen::GpuDevice(&stream);
void copyToHost() {
#if TENSORFLOW_USE_ROCM
hipMemcpy(hostbuf, outbuf, sizeof(int) * bufsize, hipMemcpyDeviceToHost);
#endif
}
virtual void SetUp() {
#if GOOGLE_CUDA
cudaError_t err = cudaMallocManaged(&outbuf, sizeof(int) * bufsize);
outbuf_host = outbuf;
#else
cudaError_t err = hipMalloc(&outbuf, sizeof(int) * bufsize);
outbuf_host = hostbuf;
#endif
ASSERT_EQ(cudaSuccess, err) << cudaGetErrorString(err);
}
virtual void TearDown() {
cudaDeviceSynchronize();
cudaFree(outbuf);
gpuDeviceSynchronize();
gpuFree(outbuf);
outbuf = nullptr;
}
};
@ -158,28 +179,32 @@ TEST_F(GpuLaunchConfigTest, GetGpuLaunchConfig) {
GpuLaunchConfig cfg;
// test valid inputs
#define TEST_LAUNCH_PARAMETER(work_element_count) \
cfg = GetGpuLaunchConfig(bufsize, d); \
TF_CHECK_OK(GpuLaunchKernel(SetOutbufZero, cfg.block_count, \
cfg.thread_per_block, 0, d.stream(), cfg, \
outbuf)); \
CUDA_ASSERT_SUCCESS \
cfg = GetGpuLaunchConfig(work_element_count, d); \
TF_CHECK_OK(GpuLaunchKernel(Count1D, cfg.block_count, cfg.thread_per_block, \
0, d.stream(), cfg, bufsize, outbuf)); \
CUDA_EXPECT_SUCCESS \
EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0)); \
\
cfg = GetGpuLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \
TF_CHECK_OK(GpuLaunchKernel(SetOutbufZero, cfg.block_count, \
cfg.thread_per_block, 0, d.stream(), cfg, \
outbuf)); \
CUDA_ASSERT_SUCCESS \
cfg = GetGpuLaunchConfig(work_element_count, d, Count1D, 0, 0); \
TF_CHECK_OK(GpuLaunchKernel(Count1D, cfg.block_count, cfg.thread_per_block, \
0, d.stream(), cfg, bufsize, outbuf)); \
CUDA_EXPECT_SUCCESS \
EXPECT_EQ(work_element_count, std::accumulate(outbuf, outbuf + bufsize, 0))
#define TEST_LAUNCH_PARAMETER(work_element_count) \
cfg = GetGpuLaunchConfig(bufsize, d); \
TF_CHECK_OK(GpuLaunchKernel(SetOutbufZero, cfg.block_count, \
cfg.thread_per_block, 0, d.stream(), cfg, \
outbuf)); \
CUDA_ASSERT_SUCCESS \
cfg = GetGpuLaunchConfig(work_element_count, d); \
TF_CHECK_OK(GpuLaunchKernel(Count1D, cfg.block_count, cfg.thread_per_block, \
0, d.stream(), cfg, bufsize, outbuf)); \
CUDA_EXPECT_SUCCESS \
copyToHost(); \
EXPECT_EQ(work_element_count, \
std::accumulate(outbuf_host, outbuf_host + bufsize, 0)); \
\
cfg = GetGpuLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \
TF_CHECK_OK(GpuLaunchKernel(SetOutbufZero, cfg.block_count, \
cfg.thread_per_block, 0, d.stream(), cfg, \
outbuf)); \
CUDA_ASSERT_SUCCESS \
cfg = GetGpuLaunchConfig(work_element_count, d, Count1D, 0, 0); \
TF_CHECK_OK(GpuLaunchKernel(Count1D, cfg.block_count, cfg.thread_per_block, \
0, d.stream(), cfg, bufsize, outbuf)); \
CUDA_EXPECT_SUCCESS \
copyToHost(); \
EXPECT_EQ(work_element_count, \
std::accumulate(outbuf_host, outbuf_host + bufsize, 0));
TEST_LAUNCH_PARAMETER(128);
TEST_LAUNCH_PARAMETER(129);
@ -221,7 +246,9 @@ TEST_F(GpuLaunchConfigTest, GetGpu2DLaunchConfig) {
TF_EXPECT_OK(GpuLaunchKernel(Count2D, cfg.block_count, cfg.thread_per_block, \
0, d.stream(), cfg, bufsize, outbuf)); \
CUDA_EXPECT_SUCCESS \
EXPECT_EQ(dimx* dimy, std::accumulate(outbuf, outbuf + bufsize, 0)); \
copyToHost(); \
EXPECT_EQ(dimx* dimy, \
std::accumulate(outbuf_host, outbuf_host + bufsize, 0)); \
\
cfg1d = GetGpuLaunchConfig(bufsize, d, SetOutbufZero, 0, 0); \
TF_EXPECT_OK(GpuLaunchKernel(SetOutbufZero, cfg1d.block_count, \
@ -232,7 +259,8 @@ TEST_F(GpuLaunchConfigTest, GetGpu2DLaunchConfig) {
TF_EXPECT_OK(GpuLaunchKernel(Count2D, cfg.block_count, cfg.thread_per_block, \
0, d.stream(), cfg, bufsize, outbuf)); \
CUDA_EXPECT_SUCCESS \
EXPECT_EQ(dimx* dimy, std::accumulate(outbuf, outbuf + bufsize, 0))
copyToHost(); \
EXPECT_EQ(dimx* dimy, std::accumulate(outbuf_host, outbuf_host + bufsize, 0))
TEST_LAUNCH_PARAMETER(128, 128);
TEST_LAUNCH_PARAMETER(129, 64);
@ -263,7 +291,9 @@ TEST_F(GpuLaunchConfigTest, GetGpu3DLaunchConfig) {
TF_EXPECT_OK(GpuLaunchKernel(Count3D, cfg.block_count, cfg.thread_per_block, \
0, d.stream(), cfg, bufsize, outbuf)); \
CUDA_EXPECT_SUCCESS \
EXPECT_EQ(dimx* dimy* dimz, std::accumulate(outbuf, outbuf + bufsize, 0))
copyToHost(); \
EXPECT_EQ(dimx* dimy* dimz, \
std::accumulate(outbuf_host, outbuf_host + bufsize, 0))
TEST_LAUNCH_PARAMETER(128, 128, 128);
TEST_LAUNCH_PARAMETER(129, 64, 1024);
@ -282,15 +312,19 @@ TEST_F(GpuLaunchConfigTest, GetGpu3DLaunchConfig) {
TEST(CudaDeviceFunctionsTest, ShuffleGetSrcLane) {
unsigned* failure_count;
#if GOOGLE_CUDA
ASSERT_EQ(cudaMallocManaged(&failure_count, sizeof(unsigned)), cudaSuccess);
#else
ASSERT_EQ(hipHostMalloc(&failure_count, sizeof(unsigned), 0), cudaSuccess);
#endif
*failure_count = 0;
TF_EXPECT_OK(GpuLaunchKernel(CudaShuffleGetSrcLaneTest, 1, 32, 0, nullptr,
failure_count));
ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);
TF_EXPECT_OK(GpuLaunchKernel(GpuShuffleGetSrcLaneTest, 1, TF_RED_WARPSIZE, 0,
nullptr, failure_count));
ASSERT_EQ(gpuDeviceSynchronize(), cudaSuccess);
ASSERT_EQ(*failure_count, 0);
cudaFree(failure_count);
gpuFree(failure_count);
}
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM