[ROCm] Misc updates for the ROCm platform
This PR contains the following updates (to sync contents from the ROCm TF fork):
* *Cuda* -> *Gpu* renamings
* clang-format changes
parent 1fd5537b74
commit a9efdf00de
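For reviewers unfamiliar with the renaming, the sketch below shows the pattern the hunks in this PR follow: the CUDA-prefixed helpers (`CUDA_1D_KERNEL_LOOP`, `CudaAtomicAdd`, `CudaLaunchConfig`/`GetCudaLaunchConfig`, `CudaGridRangeX`) become the platform-neutral GPU-prefixed ones so the same kernels build for both CUDA and ROCm. This is illustrative only, not part of the diff: the kernel and wrapper names (`SumAbsKernel`, `RunSumAbs`) are made up, and the helpers are assumed to come from `tensorflow/core/util/gpu_kernel_helper.h`.

```cpp
// Illustrative sketch only -- not part of this diff. Assumes the GPU helper
// macros/functions from tensorflow/core/util/gpu_kernel_helper.h.
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

// Hypothetical kernel: accumulates |data[i]| into *result.
template <typename T>
__global__ void SumAbsKernel(const T* data, int size, T* result) {
  // Before this PR the loop macro was CUDA_1D_KERNEL_LOOP and the atomic was
  // CudaAtomicAdd; the GPU-prefixed names compile for CUDA and ROCm builds.
  GPU_1D_KERNEL_LOOP(i, size) {
    T v = data[i] < T(0) ? -data[i] : data[i];
    GpuAtomicAdd(result, v);
  }
}

// Hypothetical launch wrapper showing the renamed launch-config helpers.
template <typename T>
void RunSumAbs(const Eigen::GpuDevice& d, const T* data, int size, T* result) {
  // Before: CudaLaunchConfig cfg = GetCudaLaunchConfig(size, d);
  GpuLaunchConfig cfg = GetGpuLaunchConfig(size, d);
  TF_CHECK_OK(GpuLaunchKernel(SumAbsKernel<T>, cfg.block_count,
                              cfg.thread_per_block, 0, d.stream(), data, size,
                              result));
}

}  // namespace tensorflow
```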
@@ -1909,4 +1909,4 @@ void GPUKernelTracker::RecordTerminated(uint64 queued_count) {
 
 }  // namespace tensorflow
 
-#endif  // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -36,7 +36,7 @@ typedef Eigen::GpuDevice GPUDevice;
 // A Cuda kernel to check if each element is Inf or Nan. If any exists, the
 // relevant elements in abnormal_detected will be set
 template <typename T>
-__global__ void CheckNumericsKernel(const T *data, int size,
+__global__ void CheckNumericsKernel(const T* data, int size,
                                     int abnormal_detected[2]) {
   const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
   const int32 total_thread_count = gridDim.x * blockDim.x;
@@ -60,7 +60,7 @@ __global__ void CheckNumericsKernel(const T *data, int size,
 // abnormality in the given array
 template <typename T>
 struct CheckNumericsLaunch {
-  void Run(const GPUDevice &d, const T *data, int size,
+  void Run(const GPUDevice& d, const T* data, int size,
            int abnormal_detected[2]) {
     const int32 block_size = d.maxGpuThreadsPerBlock();
     const int32 num_blocks =
@@ -95,7 +95,7 @@ __global__ void DeterminantFromPivotedLUKernel(int nthreads, int n,
   // since this cheap O(n) kernel always follows an O(n^3) LU factorization.
   // The main purpose is to avoid having to copy the LU decomposition to
   // host memory.
-  CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
+  GPU_1D_KERNEL_LOOP(o_idx, nthreads) {
     // Initialize sign to (-1)^order.
     const int order = PermutationOrder(n, all_pivots + o_idx * n);
     Scalar prod_sign = order % 2 ? Scalar(-1) : Scalar(1);
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
-#define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
+#ifndef TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_H_
+#define TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_H_
 
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
@@ -118,4 +118,4 @@ class GpuDeviceArrayOnHost {
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-#endif  // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
+#endif  // TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_H_
@@ -15,8 +15,8 @@ limitations under the License.
 
 // Contains structs and functions to be included in device code.
 
-#ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
-#define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_GPU_H_
+#define TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_GPU_H_
 
 #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
     (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
@@ -48,4 +48,4 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetGpuDeviceArrayOnDevice(
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
 
-#endif  // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
+#endif  // TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_GPU_H_
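For context on the header hunks above: the renamed headers keep the usual TensorFlow layout of an include guard wrapping a CUDA-or-ROCm platform guard. The sketch below is a reduced illustration with a made-up header name, not the actual contents of the renamed files.

```cpp
// Reduced sketch with a hypothetical header name -- not the real file.
#ifndef TENSORFLOW_CORE_KERNELS_EXAMPLE_GPU_UTIL_H_
#define TENSORFLOW_CORE_KERNELS_EXAMPLE_GPU_UTIL_H_

// Compile the GPU-only declarations when building for either backend.
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

// ... device-side declarations would go here ...

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_EXAMPLE_GPU_UTIL_H_
```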
@@ -66,7 +66,7 @@ __global__ void ComputePermutationFromTranspositionsKernel(
   // We only parallelize over batches here. Performance is not critical,
   // since this cheap O(num_rows) kernel always follows an O(num_rows^3)
   // LU factorization.
-  CUDA_1D_KERNEL_LOOP(index, config.virtual_thread_count) {
+  GPU_1D_KERNEL_LOOP(index, config.virtual_thread_count) {
     ComputePermutationFromTranspositions(
         num_rows, all_pivots + index * num_rows,
         all_permutation_indices + index * num_rows);
@@ -63,10 +63,10 @@ __global__ void ComputeValueOfVKernel(Gpu2DLaunchConfig config, int64 m,
                                       int64 ldu, const Scalar* M,
                                       const Scalar* U, const Scalar* S,
                                       Scalar* V) {
-  CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
-    CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
+  GPU_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
+    GPU_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
       Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
-      CudaAtomicAdd(V + batch, v);
+      GpuAtomicAdd(V + batch, v);
     }
   }
 }
@@ -75,7 +75,7 @@ __global__ void ComputeValueOfVKernel(Gpu2DLaunchConfig config, int64 m,
 // V[i] = V[i]>=0 ? 1 : 0
 template <class Scalar>
 __global__ void ExtractSignOfVKernel(GpuLaunchConfig config, Scalar* V) {
-  CUDA_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
+  GPU_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
     V[i] = V[i] >= 0 ? Scalar(1) : Scalar(-1);
   }
 }
@@ -195,7 +195,7 @@ class SvdOpGpu : public AsyncOpKernel {
       // 1. compute the (batched) sum
       const GPUDevice& d = context->eigen_device<GPUDevice>();
       d.memset(outputV_ptr, 0, batch_size * sizeof(Scalar));
-      Gpu2DLaunchConfig cfg2D = GetCuda2DLaunchConfig(batch_size, m, d);
+      Gpu2DLaunchConfig cfg2D = GetGpu2DLaunchConfig(batch_size, m, d);
       TF_CHECK_OK(GpuLaunchKernel(ComputeValueOfVKernel<Scalar>,
                                   cfg2D.block_count, cfg2D.thread_per_block, 0,
                                   d.stream(), cfg2D, m, full_matrices_ ? m : p,
@@ -77,7 +77,7 @@ class TridiagonalMatMulOpGpu : public OpKernel {
     OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output));
 
     const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
-    CudaLaunchConfig cfg = GetGpuLaunchConfig(1, device);
+    GpuLaunchConfig cfg = GetGpuLaunchConfig(1, device);
     TF_CHECK_OK(GpuLaunchKernel(
         TridiagonalMatMulKernel<Scalar>, cfg.block_count, cfg.thread_per_block,
         0, device.stream(), batch_size, m, n, superdiag.flat<Scalar>().data(),
@@ -48,7 +48,7 @@ __global__ void SolveForSizeOneOrTwoKernel(const int m, const Scalar* diags,
       *not_invertible = true;
       return;
     }
-    for (int i : CudaGridRangeX(num_rhs)) {
+    for (int i : GpuGridRangeX(num_rhs)) {
       x[i] = rhs[i] / diags[1];
     }
   } else {
@@ -57,7 +57,7 @@ __global__ void SolveForSizeOneOrTwoKernel(const int m, const Scalar* diags,
       *not_invertible = true;
       return;
     }
-    for (int i : CudaGridRangeX(num_rhs)) {
+    for (int i : GpuGridRangeX(num_rhs)) {
       x[i] = (diags[3] * rhs[i] - diags[0] * rhs[i + num_rhs]) / det;
       x[i + num_rhs] = (diags[2] * rhs[i + num_rhs] - diags[5] * rhs[i]) / det;
     }