Merge pull request #30611 from ROCmSoftwarePlatform:google_upstream_rocm_update_190711

PiperOrigin-RevId: 258814126
This commit is contained in:
TensorFlower Gardener 2019-07-18 12:25:03 -07:00
commit a5780f657f
11 changed files with 33 additions and 30 deletions

View File

@ -1909,4 +1909,4 @@ void GPUKernelTracker::RecordTerminated(uint64 queued_count) {
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -132,9 +132,9 @@ static void BM_gpu_float_int64(int iters, int num) {
testing::BytesProcessed(static_cast<int64>(iters) * num *
(sizeof(float) + sizeof(int64)));
testing::UseRealTime();
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
test::Benchmark("gpu", Cast<float, int64>(num)).Run(iters);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
test::Benchmark("sycl", Cast<float, int64>(num)).Run(iters);
#endif // TENSORFLOW_USE_SYCL
@ -155,9 +155,9 @@ static void BM_gpu_bool_float(int iters, int num) {
testing::BytesProcessed(static_cast<int64>(iters) * num *
(sizeof(bool) + sizeof(float)));
testing::UseRealTime();
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
test::Benchmark("gpu", Cast<bool, float>(num)).Run(iters);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#ifdef TENSORFLOW_USE_SYCL
test::Benchmark("sycl", Cast<bool, float>(num)).Run(iters);
#endif // TENSORFLOW_USE_SYCL
@ -205,9 +205,9 @@ static void BM_gpu_float_half(int iters, int num) {
testing::BytesProcessed(static_cast<int64>(iters) * num *
(sizeof(float) + sizeof(Eigen::half)));
testing::UseRealTime();
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
test::Benchmark("gpu", Cast<float, Eigen::half>(num)).Run(iters);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}
BENCHMARK(BM_gpu_float_half)->Arg(64 << 10)->Arg(32 << 20);
@ -216,9 +216,9 @@ static void BM_gpu_half_float(int iters, int num) {
testing::BytesProcessed(static_cast<int64>(iters) * num *
(sizeof(float) + sizeof(Eigen::half)));
testing::UseRealTime();
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
test::Benchmark("gpu", Cast<Eigen::half, float>(num)).Run(iters);
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
}
BENCHMARK(BM_gpu_half_float)->Arg(64 << 10)->Arg(32 << 20);

View File

@ -36,7 +36,7 @@ typedef Eigen::GpuDevice GPUDevice;
// A Cuda kernel to check if each element is Inf or Nan. If any exists, the
// relevant elements in abnormal_detected will be set
template <typename T>
__global__ void CheckNumericsKernel(const T *data, int size,
__global__ void CheckNumericsKernel(const T* data, int size,
int abnormal_detected[2]) {
const int32 thread_id = blockIdx.x * blockDim.x + threadIdx.x;
const int32 total_thread_count = gridDim.x * blockDim.x;
@ -60,7 +60,7 @@ __global__ void CheckNumericsKernel(const T *data, int size,
// abnormality in the given array
template <typename T>
struct CheckNumericsLaunch {
void Run(const GPUDevice &d, const T *data, int size,
void Run(const GPUDevice& d, const T* data, int size,
int abnormal_detected[2]) {
const int32 block_size = d.maxGpuThreadsPerBlock();
const int32 num_blocks =

View File

@ -95,7 +95,7 @@ __global__ void DeterminantFromPivotedLUKernel(int nthreads, int n,
// since this cheap O(n) kernel always follows an O(n^3) LU factorization.
// The main purpose is to avoid having to copy the LU decomposition to
// host memory.
CUDA_1D_KERNEL_LOOP(o_idx, nthreads) {
GPU_1D_KERNEL_LOOP(o_idx, nthreads) {
// Initialize sign to (-1)^order.
const int order = PermutationOrder(n, all_pivots + o_idx * n);
Scalar prod_sign = order % 2 ? Scalar(-1) : Scalar(1);

View File

@ -274,7 +274,8 @@ namespace internal {
template <typename T>
struct AvgPoolMeanReducer {
#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__)
#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__) && \
!defined(__HIPCC__)
// We only support packet access for floats.
static const bool PacketAccess = internal::is_same<T, float>::value;
#else
@ -303,7 +304,8 @@ struct AvgPoolMeanReducer {
return accum / T(scalarCount_);
}
#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__)
#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__) && \
!defined(__HIPCC__)
#ifdef EIGEN_VECTORIZE_AVX512
#define pequal(a, b) \
_mm512_castsi512_ps( \
@ -370,7 +372,8 @@ template <typename Device>
struct reducer_traits<AvgPoolMeanReducer<float>, Device> {
enum {
Cost = 1,
#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__)
#if (EIGEN_ARCH_i386 || EIGEN_ARCH_x86_64) && !defined(__CUDACC__) && \
!defined(__HIPCC__)
// We only support packet access for floats.
PacketAccess = true,
#else

View File

@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
#define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
#ifndef TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_H_
#define TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_H_
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
@ -118,4 +118,4 @@ class GpuDeviceArrayOnHost {
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_H_
#endif // TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_H_

View File

@ -15,8 +15,8 @@ limitations under the License.
// Contains structs and functions to be included in device code.
#ifndef TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
#define TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
#ifndef TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_GPU_H_
#define TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_GPU_H_
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
(defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
@ -48,4 +48,4 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ValueType* GetGpuDeviceArrayOnDevice(
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_KERNELS_CUDA_DEVICE_ARRAY_GPU_H_
#endif // TENSORFLOW_CORE_KERNELS_GPU_DEVICE_ARRAY_GPU_H_

View File

@ -66,7 +66,7 @@ __global__ void ComputePermutationFromTranspositionsKernel(
// We only parallelize over batches here. Performance is not critical,
// since this cheap O(num_rows) kernel always follows an O(num_rows^3)
// LU factorization.
CUDA_1D_KERNEL_LOOP(index, config.virtual_thread_count) {
GPU_1D_KERNEL_LOOP(index, config.virtual_thread_count) {
ComputePermutationFromTranspositions(
num_rows, all_pivots + index * num_rows,
all_permutation_indices + index * num_rows);

View File

@ -63,10 +63,10 @@ __global__ void ComputeValueOfVKernel(Gpu2DLaunchConfig config, int64 m,
int64 ldu, const Scalar* M,
const Scalar* U, const Scalar* S,
Scalar* V) {
CUDA_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
CUDA_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
GPU_AXIS_KERNEL_LOOP(batch, config.virtual_thread_count.x, X) {
GPU_AXIS_KERNEL_LOOP(i, config.virtual_thread_count.y, Y) {
Scalar v = M[i + m * batch] * U[ldu * (i + m * batch)] * S[batch];
CudaAtomicAdd(V + batch, v);
GpuAtomicAdd(V + batch, v);
}
}
}
@ -75,7 +75,7 @@ __global__ void ComputeValueOfVKernel(Gpu2DLaunchConfig config, int64 m,
// V[i] = V[i]>=0 ? 1 : 0
template <class Scalar>
__global__ void ExtractSignOfVKernel(GpuLaunchConfig config, Scalar* V) {
CUDA_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
GPU_1D_KERNEL_LOOP(i, config.virtual_thread_count) {
V[i] = V[i] >= 0 ? Scalar(1) : Scalar(-1);
}
}
@ -195,7 +195,7 @@ class SvdOpGpu : public AsyncOpKernel {
// 1. compute the (batched) sum
const GPUDevice& d = context->eigen_device<GPUDevice>();
d.memset(outputV_ptr, 0, batch_size * sizeof(Scalar));
Gpu2DLaunchConfig cfg2D = GetCuda2DLaunchConfig(batch_size, m, d);
Gpu2DLaunchConfig cfg2D = GetGpu2DLaunchConfig(batch_size, m, d);
TF_CHECK_OK(GpuLaunchKernel(ComputeValueOfVKernel<Scalar>,
cfg2D.block_count, cfg2D.thread_per_block, 0,
d.stream(), cfg2D, m, full_matrices_ ? m : p,

View File

@ -77,7 +77,7 @@ class TridiagonalMatMulOpGpu : public OpKernel {
OP_REQUIRES_OK(context, context->allocate_output(0, rhs.shape(), &output));
const Eigen::GpuDevice& device = context->eigen_device<Eigen::GpuDevice>();
CudaLaunchConfig cfg = GetGpuLaunchConfig(1, device);
GpuLaunchConfig cfg = GetGpuLaunchConfig(1, device);
TF_CHECK_OK(GpuLaunchKernel(
TridiagonalMatMulKernel<Scalar>, cfg.block_count, cfg.thread_per_block,
0, device.stream(), batch_size, m, n, superdiag.flat<Scalar>().data(),

View File

@ -48,7 +48,7 @@ __global__ void SolveForSizeOneOrTwoKernel(const int m, const Scalar* diags,
*not_invertible = true;
return;
}
for (int i : CudaGridRangeX(num_rhs)) {
for (int i : GpuGridRangeX(num_rhs)) {
x[i] = rhs[i] / diags[1];
}
} else {
@ -57,7 +57,7 @@ __global__ void SolveForSizeOneOrTwoKernel(const int m, const Scalar* diags,
*not_invertible = true;
return;
}
for (int i : CudaGridRangeX(num_rhs)) {
for (int i : GpuGridRangeX(num_rhs)) {
x[i] = (diags[3] * rhs[i] - diags[0] * rhs[i + num_rhs]) / det;
x[i + num_rhs] = (diags[2] * rhs[i + num_rhs] - diags[5] * rhs[i]) / det;
}