Merge pull request #29716 from ROCmSoftwarePlatform:google_upstream_rocm_platform_fix_190612
PiperOrigin-RevId: 253181917
This commit is contained in:
commit
ea4230d93b
@@ -33,7 +33,7 @@ __global__ void MatrixDiagKernel(const int num_threads, const int num_rows,
 const int lower_diag_index,
 const int upper_diag_index, const T padding,
 const T* diag_ptr, T* output_ptr) {
-CUDA_1D_KERNEL_LOOP(index, num_threads) {
+GPU_1D_KERNEL_LOOP(index, num_threads) {
 const int batch_and_row_index = index / num_cols;
 const int col = index - batch_and_row_index * num_cols;
 const int batch = batch_and_row_index / num_rows;
@@ -69,7 +69,7 @@ struct MatrixDiag<GPUDevice, T> {
 }
 GpuLaunchConfig config =
 GetGpuLaunchConfig(batch_size * num_rows * num_cols, device);
-TF_CHECK_OK(CudaLaunchKernel(
+TF_CHECK_OK(GpuLaunchKernel(
 MatrixDiagKernel<T>, config.block_count, config.thread_per_block, 0,
 device.stream(), config.virtual_thread_count, num_rows, num_cols,
 num_diags, max_diag_len, lower_diag_index, upper_diag_index, padding,
@@ -85,7 +85,7 @@ __global__ void MatrixDiagPartKernel(const int num_threads, const int num_rows,
 const int upper_diag_index,
 const T padding, const T* input_ptr,
 T* output_ptr) {
-CUDA_1D_KERNEL_LOOP(index, num_threads) {
+GPU_1D_KERNEL_LOOP(index, num_threads) {
 const int batch_and_mapped_diag_index = index / max_diag_len;
 const int index_in_the_diagonal =
 index - batch_and_mapped_diag_index * max_diag_len;
@@ -121,8 +121,8 @@ struct MatrixDiagPart<GPUDevice, T> {
 return;
 }
 GpuLaunchConfig config =
-GetCudaLaunchConfig(batch_size * num_diags * max_diag_len, device);
-TF_CHECK_OK(CudaLaunchKernel(
+GetGpuLaunchConfig(batch_size * num_diags * max_diag_len, device);
+TF_CHECK_OK(GpuLaunchKernel(
 MatrixDiagPartKernel<T>, config.block_count, config.thread_per_block, 0,
 device.stream(), config.virtual_thread_count, num_rows, num_cols,
 num_diags, max_diag_len, lower_diag_index, upper_diag_index, padding,
|
Loading…
x
Reference in New Issue
Block a user